Browse Source

zen: return cost

Frank 1 week ago
parent
commit
d86f24b6b3

+ 74 - 44
packages/console/app/src/routes/zen/util/handler.ts

@@ -38,6 +38,7 @@ type RetryOptions = {
   excludeProviders: string[]
   retryCount: number
 }
+type BillingSource = "anonymous" | "free" | "byok" | "subscription" | "balance"
 
 export async function handler(
   input: APIEvent,
@@ -51,6 +52,7 @@ export async function handler(
   type AuthInfo = Awaited<ReturnType<typeof authenticate>>
   type ModelInfo = Awaited<ReturnType<typeof validateModel>>
   type ProviderInfo = Awaited<ReturnType<typeof selectProvider>>
+  type CostInfo = ReturnType<typeof calculateCost>
 
   const MAX_FAILOVER_RETRIES = 3
   const MAX_429_RETRIES = 3
@@ -139,21 +141,22 @@ export async function handler(
           "llm.error.code": res.status,
           "llm.error.message": res.statusText,
         })
+      }
 
-        // Try another provider => stop retrying if using fallback provider
-        if (
-          // ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found.
-          res.status !== 404 &&
-          // ie. cannot change codex model providers mid-session
-          modelInfo.stickyProvider !== "strict" &&
-          modelInfo.fallbackProvider &&
-          providerInfo.id !== modelInfo.fallbackProvider
-        ) {
-          return retriableRequest({
-            excludeProviders: [...retry.excludeProviders, providerInfo.id],
-            retryCount: retry.retryCount + 1,
-          })
-        }
+      // Try another provider => stop retrying if using fallback provider
+      if (
+        res.status !== 200 &&
+        // ie. openai 404 error: Item with id 'msg_0ead8b004a3b165d0069436a6b6834819896da85b63b196a3f' not found.
+        res.status !== 404 &&
+        // ie. cannot change codex model providers mid-session
+        modelInfo.stickyProvider !== "strict" &&
+        modelInfo.fallbackProvider &&
+        providerInfo.id !== modelInfo.fallbackProvider
+      ) {
+        return retriableRequest({
+          excludeProviders: [...retry.excludeProviders, providerInfo.id],
+          retryCount: retry.retryCount + 1,
+        })
       }
 
       return { providerInfo, reqBody, res, startTimestamp }
@@ -183,18 +186,25 @@ export async function handler(
 
     // Handle non-streaming response
     if (!isStream) {
-      const responseConverter = createResponseConverter(providerInfo.format, opts.format)
       const json = await res.json()
-      const body = JSON.stringify(responseConverter(json))
+      const usageInfo = providerInfo.normalizeUsage(json.usage)
+      const costInfo = calculateCost(modelInfo, usageInfo)
+      await trialLimiter?.track(usageInfo)
+      await rateLimiter?.track()
+      await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
+      await reload(billingSource, authInfo, costInfo)
+
+      const responseConverter = createResponseConverter(providerInfo.format, opts.format)
+      const body = JSON.stringify(
+        responseConverter({
+          ...json,
+          cost: calculateOccuredCost(billingSource, costInfo),
+        }),
+      )
       logger.metric({ response_length: body.length })
       logger.debug("RESPONSE: " + body)
       dataDumper?.provideResponse(body)
       dataDumper?.flush()
-      const tokensInfo = providerInfo.normalizeUsage(json.usage)
-      await trialLimiter?.track(tokensInfo)
-      await rateLimiter?.track()
-      const costInfo = await trackUsage(authInfo, modelInfo, providerInfo, billingSource, tokensInfo)
-      await reload(authInfo, costInfo)
       return new Response(body, {
         status: resStatus,
         statusText: res.statusText,
@@ -226,12 +236,16 @@ export async function handler(
                 dataDumper?.flush()
                 await rateLimiter?.track()
                 const usage = usageParser.retrieve()
+                let cost = "0"
                 if (usage) {
-                  const tokensInfo = providerInfo.normalizeUsage(usage)
-                  await trialLimiter?.track(tokensInfo)
-                  const costInfo = await trackUsage(authInfo, modelInfo, providerInfo, billingSource, tokensInfo)
-                  await reload(authInfo, costInfo)
+                  const usageInfo = providerInfo.normalizeUsage(usage)
+                  const costInfo = calculateCost(modelInfo, usageInfo)
+                  await trialLimiter?.track(usageInfo)
+                  await trackUsage(billingSource, authInfo, modelInfo, providerInfo, usageInfo, costInfo)
+                  await reload(billingSource, authInfo, costInfo)
+                  cost = calculateOccuredCost(billingSource, costInfo)
                 }
+                c.enqueue(encoder.encode(usageParser.buidlCostChunk(cost)))
                 c.close()
                 return
               }
@@ -283,7 +297,6 @@ export async function handler(
         return pump()
       },
     })
-
     return new Response(stream, {
       status: resStatus,
       statusText: res.statusText,
@@ -498,9 +511,9 @@ export async function handler(
     }
   }
 
-  function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo) {
+  function validateBilling(authInfo: AuthInfo, modelInfo: ModelInfo): BillingSource {
     if (!authInfo) return "anonymous"
-    if (authInfo.provider?.credentials) return "free"
+    if (authInfo.provider?.credentials) return "byok"
     if (authInfo.isFree) return "free"
     if (modelInfo.allowAnonymous) return "free"
 
@@ -613,13 +626,7 @@ export async function handler(
     return res
   }
 
-  async function trackUsage(
-    authInfo: AuthInfo,
-    modelInfo: ModelInfo,
-    providerInfo: ProviderInfo,
-    billingSource: ReturnType<typeof validateBilling>,
-    usageInfo: UsageInfo,
-  ) {
+  function calculateCost(modelInfo: ModelInfo, usageInfo: UsageInfo) {
     const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
       usageInfo
 
@@ -657,6 +664,33 @@ export async function handler(
       (cacheReadCost ?? 0) +
       (cacheWrite5mCost ?? 0) +
       (cacheWrite1hCost ?? 0)
+    return {
+      totalCostInCent,
+      inputCost,
+      outputCost,
+      reasoningCost,
+      cacheReadCost,
+      cacheWrite5mCost,
+      cacheWrite1hCost,
+    }
+  }
+
+  function calculateOccuredCost(billingSource: BillingSource, costInfo: CostInfo) {
+    return billingSource === "balance" ? (costInfo.totalCostInCent / 100).toFixed(8) : "0"
+  }
+
+  async function trackUsage(
+    billingSource: BillingSource,
+    authInfo: AuthInfo,
+    modelInfo: ModelInfo,
+    providerInfo: ProviderInfo,
+    usageInfo: UsageInfo,
+    costInfo: CostInfo,
+  ) {
+    const { inputTokens, outputTokens, reasoningTokens, cacheReadTokens, cacheWrite5mTokens, cacheWrite1hTokens } =
+      usageInfo
+    const { totalCostInCent, inputCost, outputCost, reasoningCost, cacheReadCost, cacheWrite5mCost, cacheWrite1hCost } =
+      costInfo
 
     logger.metric({
       "tokens.input": inputTokens,
@@ -677,7 +711,7 @@ export async function handler(
     if (billingSource === "anonymous") return
     authInfo = authInfo!
 
-    const cost = authInfo.provider?.credentials ? 0 : centsToMicroCents(totalCostInCent)
+    const cost = centsToMicroCents(totalCostInCent)
     await Database.use((db) =>
       Promise.all([
         db.insert(UsageTable).values({
@@ -772,16 +806,12 @@ export async function handler(
     return { costInMicroCents: cost }
   }
 
-  async function reload(authInfo: AuthInfo, costInfo: Awaited<ReturnType<typeof trackUsage>>) {
-    if (!authInfo) return
-    if (authInfo.isFree) return
-    if (authInfo.provider?.credentials) return
-    if (authInfo.subscription) return
-
-    if (!costInfo) return
+  async function reload(billingSource: BillingSource, authInfo: AuthInfo, costInfo: CostInfo) {
+    if (billingSource !== "balance") return
+    authInfo = authInfo!
 
     const reloadTrigger = centsToMicroCents((authInfo.billing.reloadTrigger ?? Billing.RELOAD_TRIGGER) * 100)
-    if (authInfo.billing.balance - costInfo.costInMicroCents >= reloadTrigger) return
+    if (authInfo.billing.balance - costInfo.totalCostInCent >= reloadTrigger) return
     if (authInfo.billing.timeReloadLockedTill && authInfo.billing.timeReloadLockedTill > new Date()) return
 
     const lock = await Database.use((tx) =>

+ 1 - 0
packages/console/app/src/routes/zen/util/provider/anthropic.ts

@@ -167,6 +167,7 @@ export const anthropicHelper: ProviderHelper = ({ reqModel, providerModel }) =>
           }
         },
         retrieve: () => usage,
+        buidlCostChunk: (cost: string) => `event: ping\ndata: ${JSON.stringify({ type: "ping", cost })}\n\n`,
       }
     },
     normalizeUsage: (usage: Usage) => ({

+ 1 - 0
packages/console/app/src/routes/zen/util/provider/google.ts

@@ -56,6 +56,7 @@ export const googleHelper: ProviderHelper = ({ providerModel }) => ({
         usage = json.usageMetadata
       },
       retrieve: () => usage,
+      buidlCostChunk: (cost: string) => `data: ${JSON.stringify({ type: "ping", cost })}\n\n`,
     }
   },
   normalizeUsage: (usage: Usage) => {

+ 1 - 0
packages/console/app/src/routes/zen/util/provider/openai-compatible.ts

@@ -54,6 +54,7 @@ export const oaCompatHelper: ProviderHelper = () => ({
         usage = json.usage
       },
       retrieve: () => usage,
+      buidlCostChunk: (cost: string) => `data: ${JSON.stringify({ choices: [], cost })}\n\n`,
     }
   },
   normalizeUsage: (usage: Usage) => {

+ 1 - 0
packages/console/app/src/routes/zen/util/provider/openai.ts

@@ -43,6 +43,7 @@ export const openaiHelper: ProviderHelper = () => ({
         usage = json.response.usage
       },
       retrieve: () => usage,
+      buidlCostChunk: (cost: string) => `event: ping\ndata: ${JSON.stringify({ type: "ping", cost })}\n\n`,
     }
   },
   normalizeUsage: (usage: Usage) => {

+ 1 - 0
packages/console/app/src/routes/zen/util/provider/provider.ts

@@ -43,6 +43,7 @@ export type ProviderHelper = (input: { reqModel: string; providerModel: string }
   createUsageParser: () => {
     parse: (chunk: string) => void
     retrieve: () => any
+    buidlCostChunk: (cost: string) => string
   }
   normalizeUsage: (usage: any) => UsageInfo
 }