Frank 2 месяцев назад
Родитель
Сommit
4380727727

+ 26 - 12
packages/console/app/src/routes/zen/util/handler.ts

@@ -73,8 +73,16 @@ export async function handler(
     const stickyProvider = await stickyTracker?.get()
 
     const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
-      const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry, stickyProvider)
-      const authInfo = await authenticate(modelInfo, providerInfo)
+      const authInfo = await authenticate(modelInfo)
+      const providerInfo = selectProvider(
+        zenData,
+        authInfo,
+        modelInfo,
+        sessionId,
+        isTrial ?? false,
+        retry,
+        stickyProvider,
+      )
       validateBilling(authInfo, modelInfo)
       validateModelSettings(authInfo)
       updateProviderKey(authInfo, providerInfo)
@@ -291,6 +299,7 @@ export async function handler(
 
   function selectProvider(
     zenData: ZenData,
+    authInfo: AuthInfo,
     modelInfo: ModelInfo,
     sessionId: string,
     isTrial: boolean,
@@ -298,6 +307,10 @@ export async function handler(
     stickyProvider: string | undefined,
   ) {
     const provider = (() => {
+      if (authInfo?.provider?.credentials) {
+        return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
+      }
+
       if (isTrial) {
         return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider)
       }
@@ -342,15 +355,15 @@ export async function handler(
     }
   }
 
-  async function authenticate(modelInfo: ModelInfo, providerInfo: ProviderInfo) {
+  async function authenticate(modelInfo: ModelInfo) {
     const apiKey = opts.parseApiKey(input.request.headers)
     if (!apiKey || apiKey === "public") {
       if (modelInfo.allowAnonymous) return
       throw new AuthError("Missing API key.")
     }
 
-    const data = await Database.use((tx) =>
-      tx
+    const data = await Database.use((tx) => {
+      const query = tx
         .select({
           apiKey: KeyTable.id,
           workspaceID: KeyTable.workspaceID,
@@ -378,13 +391,15 @@ export async function handler(
         .innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
         .innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)))
         .leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, modelInfo.id)))
-        .leftJoin(
+
+      if (modelInfo.byokProvider) {
+        query.leftJoin(
           ProviderTable,
-          and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, providerInfo.id)),
+          and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, modelInfo.byokProvider)),
         )
-        .where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
-        .then((rows) => rows[0]),
-    )
+      }
+      return query.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted))).then((rows) => rows[0])
+    })
 
     if (!data) throw new AuthError("Invalid API key.")
     logger.metric({
@@ -457,8 +472,7 @@ export async function handler(
   }
 
   function updateProviderKey(authInfo: AuthInfo, providerInfo: ProviderInfo) {
-    if (!authInfo) return
-    if (!authInfo.provider?.credentials) return
+    if (!authInfo?.provider?.credentials) return
     providerInfo.apiKey = authInfo.provider.credentials
   }
 

+ 1 - 0
packages/console/core/src/model.ts

@@ -24,6 +24,7 @@ export namespace ZenData {
     cost: ModelCostSchema,
     cost200K: ModelCostSchema.optional(),
     allowAnonymous: z.boolean().optional(),
+    byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
     stickyProvider: z.boolean().optional(),
     trial: z
       .object({