|
|
@@ -17,6 +17,7 @@ import type { Format, TransformState } from "../converters/types";
|
|
|
import { mapClientFormatToTransformer, mapProviderTypeToTransformer } from "./format-mapper";
|
|
|
import { AsyncTaskManager } from "@/lib/async-task-manager";
|
|
|
import { isClientAbortError } from "./errors";
|
|
|
+import type { SessionUsageUpdate } from "@/types/session";
|
|
|
|
|
|
export type UsageMetrics = {
|
|
|
input_tokens?: number;
|
|
|
@@ -100,6 +101,32 @@ export class ProxyResponseHandler {
|
|
|
const abortController = new AbortController();
|
|
|
|
|
|
const processingPromise = (async () => {
|
|
|
+ const finalizeNonStreamAbort = async (): Promise<void> => {
|
|
|
+ if (messageContext) {
|
|
|
+ const duration = Date.now() - session.startTime;
|
|
|
+ await updateMessageRequestDuration(messageContext.id, duration);
|
|
|
+ await updateMessageRequestDetails(messageContext.id, {
|
|
|
+ statusCode: statusCode,
|
|
|
+ providerChain: session.getProviderChain(),
|
|
|
+ });
|
|
|
+ const tracker = ProxyStatusTracker.getInstance();
|
|
|
+ tracker.endRequest(messageContext.user.id, messageContext.id);
|
|
|
+ }
|
|
|
+
|
|
|
+ if (session.sessionId) {
|
|
|
+ const sessionUsagePayload: SessionUsageUpdate = {
|
|
|
+ status: statusCode >= 200 && statusCode < 300 ? "completed" : "error",
|
|
|
+ statusCode: statusCode,
|
|
|
+ };
|
|
|
+
|
|
|
+ void SessionManager.updateSessionUsage(session.sessionId, sessionUsagePayload).catch(
|
|
|
+ (error: unknown) => {
|
|
|
+ logger.error("[ResponseHandler] Failed to update session usage:", error);
|
|
|
+ }
|
|
|
+ );
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
try {
|
|
|
// ✅ 检查客户端是否断开
|
|
|
if (session.clientAbortSignal?.aborted || abortController.signal.aborted) {
|
|
|
@@ -107,6 +134,15 @@ export class ProxyResponseHandler {
|
|
|
taskId,
|
|
|
providerId: provider.id,
|
|
|
});
|
|
|
+ try {
|
|
|
+ await finalizeNonStreamAbort();
|
|
|
+ } catch (finalizeError) {
|
|
|
+ logger.error("ResponseHandler: Failed to finalize aborted non-stream response", {
|
|
|
+ taskId,
|
|
|
+ providerId: provider.id,
|
|
|
+ finalizeError,
|
|
|
+ });
|
|
|
+ }
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
@@ -208,6 +244,15 @@ export class ProxyResponseHandler {
|
|
|
? "Response transmission interrupted"
|
|
|
: "Client disconnected",
|
|
|
});
|
|
|
+ try {
|
|
|
+ await finalizeNonStreamAbort();
|
|
|
+ } catch (finalizeError) {
|
|
|
+ logger.error("ResponseHandler: Failed to finalize aborted non-stream response", {
|
|
|
+ taskId,
|
|
|
+ providerId: provider.id,
|
|
|
+ finalizeError,
|
|
|
+ });
|
|
|
+ }
|
|
|
} else {
|
|
|
logger.error("Failed to handle non-stream log:", error);
|
|
|
}
|
|
|
@@ -307,36 +352,16 @@ export class ProxyResponseHandler {
|
|
|
const reader = internalStream.getReader();
|
|
|
const decoder = new TextDecoder();
|
|
|
const chunks: string[] = [];
|
|
|
- let usageForCost: UsageMetrics | null = null;
|
|
|
-
|
|
|
- try {
|
|
|
- while (true) {
|
|
|
- // ✅ 检查取消信号
|
|
|
- if (session.clientAbortSignal?.aborted || abortController.signal.aborted) {
|
|
|
- logger.info("ResponseHandler: Stream processing cancelled", {
|
|
|
- taskId,
|
|
|
- providerId: provider.id,
|
|
|
- chunksCollected: chunks.length,
|
|
|
- });
|
|
|
- break; // 提前终止
|
|
|
- }
|
|
|
-
|
|
|
- const { value, done } = await reader.read();
|
|
|
- if (done) {
|
|
|
- break;
|
|
|
- }
|
|
|
- if (value) {
|
|
|
- chunks.push(decoder.decode(value, { stream: true }));
|
|
|
- }
|
|
|
- }
|
|
|
|
|
|
+ const flushAndJoin = (): string => {
|
|
|
const flushed = decoder.decode();
|
|
|
if (flushed) {
|
|
|
chunks.push(flushed);
|
|
|
}
|
|
|
+ return chunks.join("");
|
|
|
+ };
|
|
|
|
|
|
- const allContent = chunks.join("");
|
|
|
-
|
|
|
+ const finalizeStream = async (allContent: string): Promise<void> => {
|
|
|
// 存储响应体到 Redis(5分钟过期)
|
|
|
if (session.sessionId) {
|
|
|
void SessionManager.storeSessionResponse(session.sessionId, allContent).catch((err) => {
|
|
|
@@ -347,33 +372,31 @@ export class ProxyResponseHandler {
|
|
|
const duration = Date.now() - session.startTime;
|
|
|
await updateMessageRequestDuration(messageContext.id, duration);
|
|
|
|
|
|
- // 记录请求结束
|
|
|
const tracker = ProxyStatusTracker.getInstance();
|
|
|
tracker.endRequest(messageContext.user.id, messageContext.id);
|
|
|
|
|
|
const usageResult = parseUsageFromResponseText(allContent, provider.providerType);
|
|
|
- usageForCost = usageResult.usageMetrics;
|
|
|
+ const usageMetrics = usageResult.usageMetrics;
|
|
|
|
|
|
await updateRequestCostFromUsage(
|
|
|
messageContext.id,
|
|
|
session.getOriginalModel(),
|
|
|
session.getCurrentModel(),
|
|
|
- usageForCost,
|
|
|
+ usageMetrics,
|
|
|
provider.costMultiplier
|
|
|
);
|
|
|
|
|
|
// 追踪消费到 Redis(用于限流)
|
|
|
- await trackCostToRedis(session, usageForCost);
|
|
|
+ await trackCostToRedis(session, usageMetrics);
|
|
|
|
|
|
// 更新 session 使用量到 Redis(用于实时监控)
|
|
|
- if (session.sessionId && usageForCost) {
|
|
|
- // 计算成本(复用相同逻辑)
|
|
|
+ if (session.sessionId) {
|
|
|
let costUsdStr: string | undefined;
|
|
|
- if (session.request.model) {
|
|
|
+ if (usageMetrics && session.request.model) {
|
|
|
const priceData = await findLatestPriceByModel(session.request.model);
|
|
|
if (priceData?.priceData) {
|
|
|
const cost = calculateRequestCost(
|
|
|
- usageForCost,
|
|
|
+ usageMetrics,
|
|
|
priceData.priceData,
|
|
|
provider.costMultiplier
|
|
|
);
|
|
|
@@ -383,28 +406,63 @@ export class ProxyResponseHandler {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- void SessionManager.updateSessionUsage(session.sessionId, {
|
|
|
- inputTokens: usageForCost.input_tokens,
|
|
|
- outputTokens: usageForCost.output_tokens,
|
|
|
- cacheCreationInputTokens: usageForCost.cache_creation_input_tokens,
|
|
|
- cacheReadInputTokens: usageForCost.cache_read_input_tokens,
|
|
|
- costUsd: costUsdStr,
|
|
|
+ const sessionUsagePayload: SessionUsageUpdate = {
|
|
|
status: statusCode >= 200 && statusCode < 300 ? "completed" : "error",
|
|
|
- statusCode: statusCode,
|
|
|
- }).catch((error: unknown) => {
|
|
|
- logger.error("[ResponseHandler] Failed to update session usage:", error);
|
|
|
- });
|
|
|
+ statusCode,
|
|
|
+ };
|
|
|
+
|
|
|
+ if (usageMetrics) {
|
|
|
+ sessionUsagePayload.inputTokens = usageMetrics.input_tokens;
|
|
|
+ sessionUsagePayload.outputTokens = usageMetrics.output_tokens;
|
|
|
+ sessionUsagePayload.cacheCreationInputTokens = usageMetrics.cache_creation_input_tokens;
|
|
|
+ sessionUsagePayload.cacheReadInputTokens = usageMetrics.cache_read_input_tokens;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (costUsdStr) {
|
|
|
+ sessionUsagePayload.costUsd = costUsdStr;
|
|
|
+ }
|
|
|
+
|
|
|
+ void SessionManager.updateSessionUsage(session.sessionId, sessionUsagePayload).catch(
|
|
|
+ (error: unknown) => {
|
|
|
+ logger.error("[ResponseHandler] Failed to update session usage:", error);
|
|
|
+ }
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
// 保存扩展信息(status code, tokens, provider chain)
|
|
|
await updateMessageRequestDetails(messageContext.id, {
|
|
|
statusCode: statusCode,
|
|
|
- inputTokens: usageForCost?.input_tokens,
|
|
|
- outputTokens: usageForCost?.output_tokens,
|
|
|
- cacheCreationInputTokens: usageForCost?.cache_creation_input_tokens,
|
|
|
- cacheReadInputTokens: usageForCost?.cache_read_input_tokens,
|
|
|
+ inputTokens: usageMetrics?.input_tokens,
|
|
|
+ outputTokens: usageMetrics?.output_tokens,
|
|
|
+ cacheCreationInputTokens: usageMetrics?.cache_creation_input_tokens,
|
|
|
+ cacheReadInputTokens: usageMetrics?.cache_read_input_tokens,
|
|
|
providerChain: session.getProviderChain(),
|
|
|
});
|
|
|
+ };
|
|
|
+
|
|
|
+ try {
|
|
|
+ while (true) {
|
|
|
+ // ✅ 检查取消信号
|
|
|
+ if (session.clientAbortSignal?.aborted || abortController.signal.aborted) {
|
|
|
+ logger.info("ResponseHandler: Stream processing cancelled", {
|
|
|
+ taskId,
|
|
|
+ providerId: provider.id,
|
|
|
+ chunksCollected: chunks.length,
|
|
|
+ });
|
|
|
+ break; // 提前终止
|
|
|
+ }
|
|
|
+
|
|
|
+ const { value, done } = await reader.read();
|
|
|
+ if (done) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (value) {
|
|
|
+ chunks.push(decoder.decode(value, { stream: true }));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const allContent = flushAndJoin();
|
|
|
+ await finalizeStream(allContent);
|
|
|
} catch (error) {
|
|
|
// 检测是否为客户端中断(使用统一的精确检测函数)
|
|
|
const err = error as Error;
|
|
|
@@ -421,6 +479,16 @@ export class ProxyResponseHandler {
|
|
|
? "Response transmission interrupted"
|
|
|
: "Client disconnected",
|
|
|
});
|
|
|
+ try {
|
|
|
+ const allContent = flushAndJoin();
|
|
|
+ await finalizeStream(allContent);
|
|
|
+ } catch (finalizeError) {
|
|
|
+ logger.error("ResponseHandler: Failed to finalize aborted stream response", {
|
|
|
+ taskId,
|
|
|
+ messageId: messageContext.id,
|
|
|
+ finalizeError,
|
|
|
+ });
|
|
|
+ }
|
|
|
} else {
|
|
|
logger.error("Failed to save SSE content:", error);
|
|
|
}
|