Explorar o código

zen: failover on error

Frank hai 3 meses
pai
achega
7d56603c26

+ 69 - 42
packages/console/app/src/routes/zen/util/handler.ts

@@ -20,6 +20,10 @@ import { oaCompatHelper } from "./provider/openai-compatible"
 import { createRateLimiter } from "./rateLimiter"
 
 type ZenData = Awaited<ReturnType<typeof ZenData.list>>
+type RetryOptions = {
+  excludeProviders: string[]
+  retryCount: number
+}
 
 export async function handler(
   input: APIEvent,
@@ -32,6 +36,7 @@ export async function handler(
   type ModelInfo = Awaited<ReturnType<typeof validateModel>>
   type ProviderInfo = Awaited<ReturnType<typeof selectProvider>>
 
+  const MAX_RETRIES = 3
   const FREE_WORKSPACES = [
     "wrk_01K46JDFR0E75SG2Q8K172KF3Y", // frank
     "wrk_01K6W1A3VE0KMNVSCQT43BG2SX", // opencode bench
@@ -47,40 +52,56 @@ export async function handler(
     })
     const zenData = ZenData.list()
     const modelInfo = validateModel(zenData, body.model)
-    const providerInfo = selectProvider(zenData, modelInfo, ip)
-    const authInfo = await authenticate(modelInfo, providerInfo)
     const rateLimiter = createRateLimiter(modelInfo.id, modelInfo.rateLimit, ip)
     await rateLimiter?.check()
-    validateBilling(authInfo, modelInfo)
-    validateModelSettings(authInfo)
-    updateProviderKey(authInfo, providerInfo)
-    logger.metric({ provider: providerInfo.id })
-
-    // Request to model provider
-    const startTimestamp = Date.now()
-    const reqUrl = providerInfo.modifyUrl(providerInfo.api)
-    const reqBody = JSON.stringify(
-      providerInfo.modifyBody({
-        ...createBodyConverter(opts.format, providerInfo.format)(body),
-        model: providerInfo.model,
-      }),
-    )
-    logger.debug("REQUEST URL: " + reqUrl)
-    logger.debug("REQUEST: " + reqBody.substring(0, 300) + "...")
-    const res = await fetch(reqUrl, {
-      method: "POST",
-      headers: (() => {
-        const headers = input.request.headers
-        headers.delete("host")
-        headers.delete("content-length")
-        providerInfo.modifyHeaders(headers, body, providerInfo.apiKey)
-        Object.entries(providerInfo.headerMappings ?? {}).forEach(([k, v]) => {
-          headers.set(k, headers.get(v)!)
+
+    const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
+      const providerInfo = selectProvider(zenData, modelInfo, ip, retry)
+      const authInfo = await authenticate(modelInfo, providerInfo)
+      validateBilling(authInfo, modelInfo)
+      validateModelSettings(authInfo)
+      updateProviderKey(authInfo, providerInfo)
+      logger.metric({ provider: providerInfo.id })
+
+      const startTimestamp = Date.now()
+      const reqUrl = providerInfo.modifyUrl(providerInfo.api)
+      const reqBody = JSON.stringify(
+        providerInfo.modifyBody({
+          ...createBodyConverter(opts.format, providerInfo.format)(body),
+          model: providerInfo.model,
+        }),
+      )
+      logger.debug("REQUEST URL: " + reqUrl)
+      logger.debug("REQUEST: " + reqBody.substring(0, 300) + "...")
+      const res = await fetch(reqUrl, {
+        method: "POST",
+        headers: (() => {
+          const headers = new Headers(input.request.headers)
+          providerInfo.modifyHeaders(headers, body, providerInfo.apiKey)
+          Object.entries(providerInfo.headerMappings ?? {}).forEach(([k, v]) => {
+            headers.set(k, headers.get(v)!)
+          })
+          headers.delete("host")
+          headers.delete("content-length")
+          headers.delete("x-opencode-request")
+          headers.delete("x-opencode-session")
+          return headers
+        })(),
+        body: reqBody,
+      })
+
+      // Non-200 response: retry with another provider, unless no fallback is configured or this attempt already used the fallback provider
+      if (res.status !== 200 && modelInfo.fallbackProvider && providerInfo.id !== modelInfo.fallbackProvider) {
+        return retriableRequest({
+          excludeProviders: [...retry.excludeProviders, providerInfo.id],
+          retryCount: retry.retryCount + 1,
         })
-        return headers
-      })(),
-      body: reqBody,
-    })
+      }
+
+      return { providerInfo, authInfo, res, startTimestamp }
+    }
+
+    const { providerInfo, authInfo, res, startTimestamp } = await retriableRequest()
 
     // Scrub response headers
     const resHeaders = new Headers()
@@ -236,19 +257,25 @@ export async function handler(
     return { id: modelId, ...modelData }
   }
 
-  function selectProvider(zenData: ZenData, modelInfo: ModelInfo, ip: string) {
-    const providers = modelInfo.providers
-      .filter((provider) => !provider.disabled)
-      .flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
+  function selectProvider(zenData: ZenData, modelInfo: ModelInfo, ip: string, retry: RetryOptions) {
+    const provider = (() => {
+      if (retry.retryCount === MAX_RETRIES) {
+        return modelInfo.providers.find((provider) => provider.id === modelInfo.fallbackProvider)
+      }
 
-    // Use the last 2 characters of IP address to select a provider
-    const lastChars = ip.slice(-2)
-    const index = parseInt(lastChars, 16) % providers.length
-    const provider = providers[index || 0]
+      const providers = modelInfo.providers
+        .filter((provider) => !provider.disabled)
+        .filter((provider) => !retry.excludeProviders.includes(provider.id))
+        .flatMap((provider) => Array<typeof provider>(provider.weight ?? 1).fill(provider))
 
-    if (!(provider.id in zenData.providers)) {
-      throw new ModelError(`Provider ${provider.id} not supported`)
-    }
+      // Use the last 2 characters of the IP address, parsed as hex, to select a provider
+      const lastChars = ip.slice(-2)
+      const index = parseInt(lastChars, 16) % providers.length
+      return providers[index || 0]
+    })()
+
+    if (!provider) throw new ModelError("No provider available")
+    if (!(provider.id in zenData.providers)) throw new ModelError(`Provider ${provider.id} not supported`)
 
     return {
       ...provider,

+ 1 - 0
packages/console/core/src/model.ts

@@ -25,6 +25,7 @@ export namespace ZenData {
     cost200K: ModelCostSchema.optional(),
     allowAnonymous: z.boolean().optional(),
     rateLimit: z.number().optional(),
+    fallbackProvider: z.string().optional(),
     providers: z.array(
       z.object({
         id: z.string(),