Frank 2 weeks ago
parent
commit
801e4a8a9d

+ 1 - 1
packages/console/app/src/routes/zen/util/handler.ts

@@ -79,7 +79,7 @@ export async function handler(
     const dataDumper = createDataDumper(sessionId, requestId, projectId)
     const trialLimiter = createTrialLimiter(modelInfo.trial, ip, ocClient)
     const isTrial = await trialLimiter?.isTrial()
-    const rateLimiter = createRateLimiter(modelInfo.rateLimit, ip)
+    const rateLimiter = createRateLimiter(modelInfo.rateLimit, ip, input.request.headers)
     await rateLimiter?.check()
     const stickyTracker = createStickyTracker(modelInfo.stickyProvider, sessionId)
     const stickyProvider = await stickyTracker?.get()

+ 6 - 2
packages/console/app/src/routes/zen/util/rateLimiter.ts

@@ -4,9 +4,13 @@ import { RateLimitError } from "./error"
 import { logger } from "./logger"
 import { ZenData } from "@opencode-ai/console-core/model.js"
 
-export function createRateLimiter(limit: ZenData.RateLimit | undefined, rawIp: string) {
+export function createRateLimiter(limit: ZenData.RateLimit | undefined, rawIp: string, headers: Headers) {
   if (!limit) return
 
+  const limitValue = (limit.checkHeader && !headers.get(limit.checkHeader))
+    ? limit.fallbackValue!
+    : limit.value
+
   const ip = !rawIp.length ? "unknown" : rawIp
   const now = Date.now()
   const intervals =
@@ -32,7 +36,7 @@ export function createRateLimiter(limit: ZenData.RateLimit | undefined, rawIp: s
       )
       const total = rows.reduce((sum, r) => sum + r.count, 0)
       logger.debug(`rate limit total: ${total}`)
-      if (total >= limit.value) throw new RateLimitError(`Rate limit exceeded. Please try again later.`)
+      if (total >= limitValue) throw new RateLimitError(`Rate limit exceeded. Please try again later.`)
     },
   }
 }

+ 2 - 0
packages/console/core/src/model.ts

@@ -21,6 +21,8 @@ export namespace ZenData {
   const RateLimitSchema = z.object({
     period: z.enum(["day", "rolling"]),
     value: z.number().int(),
+    checkHeader: z.string().optional(),
+    fallbackValue: z.number().int().optional(),
   })
   export type Format = z.infer<typeof FormatSchema>
   export type Trial = z.infer<typeof TrialSchema>