Przeglądaj źródła

fix: finalize aborted proxy streams

jill 3 miesięcy temu
rodzic
commit
18b2ce20a1
1 zmienionych plików z 114 dodań i 46 usunięć
  1. 114 46
      src/app/v1/_lib/proxy/response-handler.ts

+ 114 - 46
src/app/v1/_lib/proxy/response-handler.ts

@@ -17,6 +17,7 @@ import type { Format, TransformState } from "../converters/types";
 import { mapClientFormatToTransformer, mapProviderTypeToTransformer } from "./format-mapper";
 import { AsyncTaskManager } from "@/lib/async-task-manager";
 import { isClientAbortError } from "./errors";
+import type { SessionUsageUpdate } from "@/types/session";
 
 export type UsageMetrics = {
   input_tokens?: number;
@@ -100,6 +101,32 @@ export class ProxyResponseHandler {
     const abortController = new AbortController();
 
     const processingPromise = (async () => {
+      const finalizeNonStreamAbort = async (): Promise<void> => {
+        if (messageContext) {
+          const duration = Date.now() - session.startTime;
+          await updateMessageRequestDuration(messageContext.id, duration);
+          await updateMessageRequestDetails(messageContext.id, {
+            statusCode: statusCode,
+            providerChain: session.getProviderChain(),
+          });
+          const tracker = ProxyStatusTracker.getInstance();
+          tracker.endRequest(messageContext.user.id, messageContext.id);
+        }
+
+        if (session.sessionId) {
+          const sessionUsagePayload: SessionUsageUpdate = {
+            status: statusCode >= 200 && statusCode < 300 ? "completed" : "error",
+            statusCode: statusCode,
+          };
+
+          void SessionManager.updateSessionUsage(session.sessionId, sessionUsagePayload).catch(
+            (error: unknown) => {
+              logger.error("[ResponseHandler] Failed to update session usage:", error);
+            }
+          );
+        }
+      };
+
       try {
         // ✅ 检查客户端是否断开
         if (session.clientAbortSignal?.aborted || abortController.signal.aborted) {
@@ -107,6 +134,15 @@ export class ProxyResponseHandler {
             taskId,
             providerId: provider.id,
           });
+          try {
+            await finalizeNonStreamAbort();
+          } catch (finalizeError) {
+            logger.error("ResponseHandler: Failed to finalize aborted non-stream response", {
+              taskId,
+              providerId: provider.id,
+              finalizeError,
+            });
+          }
           return;
         }
 
@@ -208,6 +244,15 @@ export class ProxyResponseHandler {
                 ? "Response transmission interrupted"
                 : "Client disconnected",
           });
+          try {
+            await finalizeNonStreamAbort();
+          } catch (finalizeError) {
+            logger.error("ResponseHandler: Failed to finalize aborted non-stream response", {
+              taskId,
+              providerId: provider.id,
+              finalizeError,
+            });
+          }
         } else {
           logger.error("Failed to handle non-stream log:", error);
         }
@@ -307,36 +352,16 @@ export class ProxyResponseHandler {
       const reader = internalStream.getReader();
       const decoder = new TextDecoder();
       const chunks: string[] = [];
-      let usageForCost: UsageMetrics | null = null;
-
-      try {
-        while (true) {
-          // ✅ 检查取消信号
-          if (session.clientAbortSignal?.aborted || abortController.signal.aborted) {
-            logger.info("ResponseHandler: Stream processing cancelled", {
-              taskId,
-              providerId: provider.id,
-              chunksCollected: chunks.length,
-            });
-            break; // 提前终止
-          }
-
-          const { value, done } = await reader.read();
-          if (done) {
-            break;
-          }
-          if (value) {
-            chunks.push(decoder.decode(value, { stream: true }));
-          }
-        }
 
+      const flushAndJoin = (): string => {
         const flushed = decoder.decode();
         if (flushed) {
           chunks.push(flushed);
         }
+        return chunks.join("");
+      };
 
-        const allContent = chunks.join("");
-
+      const finalizeStream = async (allContent: string): Promise<void> => {
         // 存储响应体到 Redis(5分钟过期)
         if (session.sessionId) {
           void SessionManager.storeSessionResponse(session.sessionId, allContent).catch((err) => {
@@ -347,33 +372,31 @@ export class ProxyResponseHandler {
         const duration = Date.now() - session.startTime;
         await updateMessageRequestDuration(messageContext.id, duration);
 
-        // 记录请求结束
         const tracker = ProxyStatusTracker.getInstance();
         tracker.endRequest(messageContext.user.id, messageContext.id);
 
         const usageResult = parseUsageFromResponseText(allContent, provider.providerType);
-        usageForCost = usageResult.usageMetrics;
+        const usageMetrics = usageResult.usageMetrics;
 
         await updateRequestCostFromUsage(
           messageContext.id,
           session.getOriginalModel(),
           session.getCurrentModel(),
-          usageForCost,
+          usageMetrics,
           provider.costMultiplier
         );
 
         // 追踪消费到 Redis(用于限流)
-        await trackCostToRedis(session, usageForCost);
+        await trackCostToRedis(session, usageMetrics);
 
         // 更新 session 使用量到 Redis(用于实时监控)
-        if (session.sessionId && usageForCost) {
-          // 计算成本(复用相同逻辑)
+        if (session.sessionId) {
           let costUsdStr: string | undefined;
-          if (session.request.model) {
+          if (usageMetrics && session.request.model) {
             const priceData = await findLatestPriceByModel(session.request.model);
             if (priceData?.priceData) {
               const cost = calculateRequestCost(
-                usageForCost,
+                usageMetrics,
                 priceData.priceData,
                 provider.costMultiplier
               );
@@ -383,28 +406,63 @@ export class ProxyResponseHandler {
             }
           }
 
-          void SessionManager.updateSessionUsage(session.sessionId, {
-            inputTokens: usageForCost.input_tokens,
-            outputTokens: usageForCost.output_tokens,
-            cacheCreationInputTokens: usageForCost.cache_creation_input_tokens,
-            cacheReadInputTokens: usageForCost.cache_read_input_tokens,
-            costUsd: costUsdStr,
+          const sessionUsagePayload: SessionUsageUpdate = {
             status: statusCode >= 200 && statusCode < 300 ? "completed" : "error",
-            statusCode: statusCode,
-          }).catch((error: unknown) => {
-            logger.error("[ResponseHandler] Failed to update session usage:", error);
-          });
+            statusCode,
+          };
+
+          if (usageMetrics) {
+            sessionUsagePayload.inputTokens = usageMetrics.input_tokens;
+            sessionUsagePayload.outputTokens = usageMetrics.output_tokens;
+            sessionUsagePayload.cacheCreationInputTokens = usageMetrics.cache_creation_input_tokens;
+            sessionUsagePayload.cacheReadInputTokens = usageMetrics.cache_read_input_tokens;
+          }
+
+          if (costUsdStr) {
+            sessionUsagePayload.costUsd = costUsdStr;
+          }
+
+          void SessionManager.updateSessionUsage(session.sessionId, sessionUsagePayload).catch(
+            (error: unknown) => {
+              logger.error("[ResponseHandler] Failed to update session usage:", error);
+            }
+          );
         }
 
         // 保存扩展信息(status code, tokens, provider chain)
         await updateMessageRequestDetails(messageContext.id, {
           statusCode: statusCode,
-          inputTokens: usageForCost?.input_tokens,
-          outputTokens: usageForCost?.output_tokens,
-          cacheCreationInputTokens: usageForCost?.cache_creation_input_tokens,
-          cacheReadInputTokens: usageForCost?.cache_read_input_tokens,
+          inputTokens: usageMetrics?.input_tokens,
+          outputTokens: usageMetrics?.output_tokens,
+          cacheCreationInputTokens: usageMetrics?.cache_creation_input_tokens,
+          cacheReadInputTokens: usageMetrics?.cache_read_input_tokens,
           providerChain: session.getProviderChain(),
         });
+      };
+
+      try {
+        while (true) {
+          // ✅ 检查取消信号
+          if (session.clientAbortSignal?.aborted || abortController.signal.aborted) {
+            logger.info("ResponseHandler: Stream processing cancelled", {
+              taskId,
+              providerId: provider.id,
+              chunksCollected: chunks.length,
+            });
+            break; // 提前终止
+          }
+
+          const { value, done } = await reader.read();
+          if (done) {
+            break;
+          }
+          if (value) {
+            chunks.push(decoder.decode(value, { stream: true }));
+          }
+        }
+
+        const allContent = flushAndJoin();
+        await finalizeStream(allContent);
       } catch (error) {
         // 检测是否为客户端中断(使用统一的精确检测函数)
         const err = error as Error;
@@ -421,6 +479,16 @@ export class ProxyResponseHandler {
                 ? "Response transmission interrupted"
                 : "Client disconnected",
           });
+          try {
+            const allContent = flushAndJoin();
+            await finalizeStream(allContent);
+          } catch (finalizeError) {
+            logger.error("ResponseHandler: Failed to finalize aborted stream response", {
+              taskId,
+              messageId: messageContext.id,
+              finalizeError,
+            });
+          }
         } else {
           logger.error("Failed to save SSE content:", error);
         }