Frank 3 дней назад
Родитель
Сommit
3b2a2c461d

+ 2 - 2
packages/console/app/src/routes/zen/util/handler.ts

@@ -106,7 +106,7 @@ export async function handler(
     const zenData = ZenData.list(opts.modelList)
     const modelInfo = validateModel(zenData, model)
     const dataDumper = createDataDumper(sessionId, requestId, projectId)
-    const trialLimiter = createTrialLimiter(modelInfo.trialProviders, ip)
+    const trialLimiter = createTrialLimiter(modelInfo.trialProvider, ip)
     const trialProviders = await trialLimiter?.check()
     const rateLimiter = createRateLimiter(
       modelInfo.id,
@@ -392,7 +392,7 @@ export async function handler(
   function validateModel(zenData: ZenData, reqModel: string) {
     if (!(reqModel in zenData.models)) throw new ModelError(t("zen.api.error.modelNotSupported", { model: reqModel }))
 
-    const modelId = reqModel as keyof typeof zenData.models
+    const modelId = reqModel
     const modelData = Array.isArray(zenData.models[modelId])
       ? zenData.models[modelId].find((model) => opts.format === model.formatFilter)
       : zenData.models[modelId]

+ 65 - 6
packages/console/core/src/model.ts

@@ -26,7 +26,7 @@ export namespace ZenData {
     allowAnonymous: z.boolean().optional(),
     byokProvider: z.enum(["openai", "anthropic", "google"]).optional(),
     stickyProvider: z.enum(["strict", "prefer"]).optional(),
-    trialProviders: z.array(z.string()).optional(),
+    trialProvider: z.string().optional(),
     trialEnded: z.boolean().optional(),
     fallbackProvider: z.string().optional(),
     rateLimit: z.number().optional(),
@@ -45,7 +45,7 @@ export namespace ZenData {
 
   const ProviderSchema = z.object({
     api: z.string(),
-    apiKey: z.string(),
+    apiKey: z.union([z.string(), z.record(z.string(), z.string())]),
     format: FormatSchema.optional(),
     headerMappings: z.record(z.string(), z.string()).optional(),
     payloadModifier: z.record(z.string(), z.any()).optional(),
@@ -54,7 +54,10 @@ export namespace ZenData {
   })
 
   const ModelsSchema = z.object({
-    models: z.record(z.string(), z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))])),
+    zenModels: z.record(
+      z.string(),
+      z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))]),
+    ),
     liteModels: z.record(
       z.string(),
       z.union([ModelSchema, z.array(ModelSchema.extend({ formatFilter: FormatSchema }))]),
@@ -99,10 +102,66 @@ export namespace ZenData {
         Resource.ZEN_MODELS29.value +
         Resource.ZEN_MODELS30.value,
     )
-    const { models, liteModels, providers } = ModelsSchema.parse(json)
+    const { zenModels, liteModels, providers } = ModelsSchema.parse(json)
+    const compositeProviders = Object.fromEntries(
+      Object.entries(providers).map(([id, provider]) => [
+        id,
+        typeof provider.apiKey === "string"
+          ? [{ id: id, key: provider.apiKey }]
+          : Object.entries(provider.apiKey).map(([kid, key]) => ({
+              id: `${id}.${kid}`,
+              key,
+            })),
+      ]),
+    )
     return {
-      models: modelList === "lite" ? liteModels : models,
-      providers,
+      providers: Object.fromEntries(
+        Object.entries(providers).flatMap(([providerId, provider]) =>
+          compositeProviders[providerId].map((p) => [p.id, { ...provider, apiKey: p.key }]),
+        ),
+      ),
+      models: (() => {
+        const normalize = (model: z.infer<typeof ModelSchema>) => {
+          const composite = model.providers.find((p) => compositeProviders[p.id].length > 1)
+          if (!composite)
+            return {
+              trialProvider: model.trialProvider ? [model.trialProvider] : undefined,
+            }
+
+          const weightMulti = compositeProviders[composite.id].length
+
+          return {
+            trialProvider: (() => {
+              if (!model.trialProvider) return undefined
+              if (model.trialProvider === composite.id) return compositeProviders[composite.id].map((p) => p.id)
+              return [model.trialProvider]
+            })(),
+            providers: model.providers.flatMap((p) =>
+              p.id === composite.id
+                ? compositeProviders[p.id].map((sub) => ({
+                    ...p,
+                    id: sub.id,
+                    weight: p.weight ?? 1,
+                  }))
+                : [
+                    {
+                      ...p,
+                      weight: (p.weight ?? 1) * weightMulti,
+                    },
+                  ],
+            ),
+          }
+        }
+
+        return Object.fromEntries(
+          Object.entries(modelList === "lite" ? liteModels : zenModels).map(([modelId, model]) => {
+            const n = Array.isArray(model)
+              ? model.map((m) => ({ ...m, ...normalize(m) }))
+              : { ...model, ...normalize(model) }
+            return [modelId, n]
+          }),
+        )
+      })(),
     }
   })
 }