response-handler-gemini-stream-passthrough-timeouts.test.ts 14 KB


  1. import { createServer } from "node:http";
  2. import type { Socket } from "node:net";
  3. import { beforeEach, describe, expect, test, vi } from "vitest";
  4. import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder";
  5. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  6. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  7. import type { Provider } from "@/types/provider";
  /** Promises handed to the mocked AsyncTaskManager.register; each test awaits them during cleanup. */
  const asyncTasks: Promise<void>[] = [];

  // Built through vi.hoisted so the vi.mock factory for "@/lib/config" can reference it.
  const mocks = vi.hoisted(() => ({
    isHttp2Enabled: vi.fn(async () => false),
  }));

  // Before every test, reset the mock and make it resolve to `false` again.
  beforeEach(() => {
    const { isHttp2Enabled } = mocks;
    isHttp2Enabled.mockReset();
    isHttp2Enabled.mockResolvedValue(false);
  });
  // Keep the real config module, replacing only isHttp2Enabled with the hoisted mock.
  vi.mock("@/lib/config", async (importOriginal) => {
    const realConfig = await importOriginal<typeof import("@/lib/config")>();
    return { ...realConfig, isHttp2Enabled: mocks.isHttp2Enabled };
  });

  // ResponseFixer.process hands the response back untouched.
  vi.mock("@/app/v1/_lib/proxy/response-fixer", () => {
    return {
      ResponseFixer: {
        process: async (_session: unknown, response: Response) => {
          return response;
        },
      },
    };
  });

  // register() records the task promise in asyncTasks and returns a fresh AbortController;
  // cleanup() and cancel() do nothing.
  vi.mock("@/lib/async-task-manager", () => {
    const register = (_taskId: string, promise: Promise<void>) => {
      asyncTasks.push(promise);
      return new AbortController();
    };
    return {
      AsyncTaskManager: {
        register,
        cleanup: () => {},
        cancel: () => {},
      },
    };
  });

  // Logger whose five level methods are bare mock functions.
  vi.mock("@/lib/logger", () => {
    return {
      logger: {
        debug: vi.fn(),
        info: vi.fn(),
        warn: vi.fn(),
        trace: vi.fn(),
        error: vi.fn(),
      },
    };
  });

  // Message repository writes are replaced by bare mock functions.
  vi.mock("@/repository/message", () => {
    return {
      updateMessageRequestCost: vi.fn(),
      updateMessageRequestDetails: vi.fn(),
      updateMessageRequestDuration: vi.fn(),
    };
  });

  // System settings always resolve to billingModelSource "original".
  vi.mock("@/repository/system-config", () => {
    const getSystemSettings = vi.fn(async () => {
      return { billingModelSource: "original" };
    });
    return { getSystemSettings };
  });

  // Price lookups always resolve to zero input and output cost per token.
  vi.mock("@/repository/model-price", () => {
    const findLatestPriceByModel = vi.fn(async () => {
      return {
        priceData: { input_cost_per_token: 0, output_cost_per_token: 0 },
      };
    });
    return { findLatestPriceByModel };
  });

  // SessionManager's two methods are replaced by bare mock functions.
  vi.mock("@/lib/session-manager", () => {
    return {
      SessionManager: {
        storeSessionResponse: vi.fn(),
        updateSessionUsage: vi.fn(),
      },
    };
  });

  // ProxyStatusTracker.getInstance() returns an object whose endRequest() does nothing.
  vi.mock("@/lib/proxy-status-tracker", () => {
    const getInstance = () => {
      return { endRequest: () => {} };
    };
    return { ProxyStatusTracker: { getInstance } };
  });
  /**
   * Builds a fully populated `Provider` fixture.
   *
   * The defaults describe an enabled "gemini" provider with a 100 ms streaming
   * first-byte timeout and the idle and non-streaming timeouts set to 0.
   *
   * @param overrides Fields that replace the corresponding defaults.
   * @returns A new provider object; `overrides` take precedence over defaults.
   */
  function createProvider(overrides: Partial<Provider> = {}): Provider {
    const defaults: Provider = {
      // Identity and routing
      id: 1,
      name: "p1",
      url: "http://127.0.0.1:1",
      key: "k",
      providerVendorId: null,
      isEnabled: true,
      weight: 1,
      priority: 0,
      groupPriorities: null,
      costMultiplier: 1,
      groupTag: null,
      providerType: "gemini",
      preserveClientIp: false,
      modelRedirects: null,
      allowedModels: null,
      mcpPassthroughType: "none",
      mcpPassthroughUrl: null,

      // Spending limits
      limit5hUsd: null,
      limitDailyUsd: null,
      dailyResetMode: "fixed",
      dailyResetTime: "00:00",
      limitWeeklyUsd: null,
      limitMonthlyUsd: null,
      limitTotalUsd: null,
      totalCostResetAt: null,
      limitConcurrentSessions: 0,

      // Retry and circuit breaker
      maxRetryAttempts: null,
      circuitBreakerFailureThreshold: 5,
      circuitBreakerOpenDuration: 1_800_000,
      circuitBreakerHalfOpenSuccessThreshold: 2,

      // Outbound proxy
      proxyUrl: null,
      proxyFallbackToDirect: false,

      // Timeouts
      firstByteTimeoutStreamingMs: 100,
      streamingIdleTimeoutMs: 0,
      requestTimeoutNonStreamingMs: 0,

      // Presentation and per-vendor preferences
      websiteUrl: null,
      faviconUrl: null,
      cacheTtlPreference: null,
      context1mPreference: null,
      codexReasoningEffortPreference: null,
      codexReasoningSummaryPreference: null,
      codexTextVerbosityPreference: null,
      codexParallelToolCallsPreference: null,
      anthropicMaxTokensPreference: null,
      anthropicThinkingBudgetPreference: null,
      geminiGoogleSearchPreference: null,

      // Rate limits
      tpm: 0,
      rpm: 0,
      rpd: 0,
      cc: 0,

      // Timestamps
      createdAt: new Date(),
      updatedAt: new Date(),
      deletedAt: null,
    };

    return { ...defaults, ...overrides };
  }
  /**
   * Creates a `ProxySession` without running its constructor: an object is made
   * with `ProxySession.prototype` as its prototype and the fields below are
   * assigned onto it directly.
   *
   * @param params.clientAbortSignal Signal stored as the session's client abort signal.
   * @param params.messageId Id stored in `messageContext.id`.
   * @param params.userId Id stored in `messageContext.user.id`.
   * @returns The populated session; its `provider` is initially `null`.
   */
  function createSession(params: {
    clientAbortSignal: AbortSignal;
    messageId: number;
    userId: number;
  }): ProxySession {
    const { clientAbortSignal, messageId, userId } = params;
    const requestHeaders = new Headers();

    const fields = {
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("https://example.com/v1/chat/completions"),
      headers: requestHeaders,
      originalHeaders: new Headers(requestHeaders),
      headerLog: JSON.stringify(Object.fromEntries(requestHeaders.entries())),
      request: {
        model: "gemini-2.0-flash",
        log: "(test)",
        message: {
          model: "gemini-2.0-flash",
          stream: true,
          messages: [{ role: "user", content: "hi" }],
        },
      },
      userAgent: null,
      context: null,
      clientAbortSignal,
      userName: "test-user",
      authState: { success: true, user: null, key: null, apiKey: null },
      provider: null,
      messageContext: {
        id: messageId,
        createdAt: new Date(),
        user: { id: userId, name: "u1" },
      },
      sessionId: null,
      requestSequence: 1,
      originalFormat: "gemini",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      isHeaderModified: () => false,
    };

    return Object.assign(Object.create(ProxySession.prototype), fields) as ProxySession;
  }
  /**
   * Starts an HTTP server on an ephemeral port of 127.0.0.1 and routes every
   * request to `handler`.
   *
   * The handler is typed as `RequestListener` directly. The previous
   * `Parameters<typeof createServer>[0]` is inferred from the last overload of
   * `createServer`, whose first parameter is the options object rather than the
   * listener.
   *
   * @param handler Request listener invoked for each incoming request.
   * @returns `baseUrl` of the listening server and a `close` function that
   *   destroys every tracked socket and then waits for the server to close.
   * @throws Rejects if the server emits "error" before listening or if the
   *   bound address cannot be read.
   */
  async function startSseServer(handler: import("node:http").RequestListener): Promise<{
    baseUrl: string;
    close: () => Promise<void>;
  }> {
    // Track open sockets so close() can tear down connections the handler left open.
    const openSockets = new Set<Socket>();
    const server = createServer(handler);

    server.on("connection", (socket) => {
      openSockets.add(socket);
      socket.on("close", () => openSockets.delete(socket));
    });

    const baseUrl = await new Promise<string>((resolve, reject) => {
      server.once("error", reject);
      server.listen(0, "127.0.0.1", () => {
        const address = server.address();
        if (!address || typeof address === "string") {
          reject(new Error("Failed to get server address"));
          return;
        }
        resolve(`http://127.0.0.1:${address.port}`);
      });
    });

    const close = async (): Promise<void> => {
      for (const socket of openSockets) {
        try {
          socket.destroy();
        } catch {
          // Best-effort teardown: a socket that fails to destroy must not block shutdown.
        }
      }
      openSockets.clear();
      // Resolves even when server.close reports an error to its callback.
      await new Promise<void>((resolve) => server.close(() => resolve()));
    };

    return { baseUrl, close };
  }
  /**
   * Performs one `reader.read()` and races it against a deadline.
   *
   * The deadline timer is cleared as soon as the race settles, so a read that
   * finishes early does not leave a live timer behind for `timeoutMs`.
   *
   * @param reader Reader to pull a single chunk from.
   * @param timeoutMs Maximum time to wait for the read to settle, in milliseconds.
   * @returns `{ ok: true, value }` when the read resolved, `{ ok: true, error }`
   *   when it rejected, or `{ ok: false, reason: "timeout" }` when the deadline
   *   fired first. The promise never rejects.
   */
  async function readWithTimeout(
    reader: ReadableStreamDefaultReader<Uint8Array>,
    timeoutMs: number
  ): Promise<
    | { ok: true; value: ReadableStreamReadResult<Uint8Array> }
    | { ok: true; error: unknown }
    | { ok: false; reason: "timeout" }
  > {
    let timer: ReturnType<typeof setTimeout> | undefined;
    const deadline = new Promise<{ ok: false; reason: "timeout" }>((resolve) => {
      timer = setTimeout(() => resolve({ ok: false, reason: "timeout" }), timeoutMs);
    });

    // A rejected read is reported as a settled outcome, not propagated.
    const attempt = reader
      .read()
      .then((value) => ({ ok: true as const, value }))
      .catch((error: unknown) => ({ ok: true as const, error }));

    try {
      return await Promise.race([attempt, deadline]);
    } finally {
      // On timeout the read itself stays pending; only the timer is released here.
      clearTimeout(timer);
    }
  }
  /**
   * End-to-end checks of the Gemini streaming passthrough timeouts against a
   * real local HTTP server: the first-byte timeout, its clearing once data
   * arrives, and the mid-stream idle timeout.
   */
  describe("ProxyResponseHandler - Gemini stream passthrough timeouts", () => {
    /** Response headers each fake upstream sends to open an SSE stream. */
    const SSE_HEADERS = {
      "content-type": "text/event-stream",
      "cache-control": "no-cache",
      connection: "keep-alive",
    };

    /**
     * Forwards the session to the upstream at `baseUrl` and passes the upstream
     * response through `ProxyResponseHandler.dispatch`.
     *
     * `doForward` is reached through a cast and invoked with `ProxyForwarder`
     * as `this`.
     *
     * @returns The response produced by `ProxyResponseHandler.dispatch`.
     */
    async function forwardAndDispatch(session: ProxySession, provider: Provider, baseUrl: string) {
      const doForward = (
        ProxyForwarder as unknown as {
          doForward: (this: typeof ProxyForwarder, ...args: unknown[]) => unknown;
        }
      ).doForward;
      const upstreamResponse = (await doForward.call(
        ProxyForwarder,
        session,
        provider,
        baseUrl
      )) as Response;
      return ProxyResponseHandler.dispatch(session, upstreamResponse);
    }

    test("不应在仅收到 headers 时清除首字节超时:无首块数据时应在窗口内中断避免悬挂", async () => {
      asyncTasks.length = 0;
      const { baseUrl, close } = await startSseServer((_req, res) => {
        res.writeHead(200, SSE_HEADERS);
        res.flushHeaders();
        // Deliberately send no body and keep the connection open.
      });
      const clientAbortController = new AbortController();
      try {
        const provider = createProvider({
          url: baseUrl,
          firstByteTimeoutStreamingMs: 200,
        });
        const session = createSession({
          clientAbortSignal: clientAbortController.signal,
          messageId: 1,
          userId: 1,
        });
        session.setProvider(provider);

        const clientResponse = await forwardAndDispatch(session, provider, baseUrl);
        const reader = clientResponse.body?.getReader();
        expect(reader).toBeTruthy();
        if (!reader) throw new Error("Missing body reader");

        const startedAt = Date.now();
        const firstRead = await readWithTimeout(reader, 1500);
        if (!firstRead.ok) {
          clientAbortController.abort(new Error("test_timeout"));
          throw new Error("首字节超时未生效:读首块数据在 1.5s 内仍未返回(可能仍会卡死)");
        }

        // The read must have ended because of the timeout/abort; done=true or an error both count.
        const ended = ("value" in firstRead && firstRead.value.done === true) || "error" in firstRead;
        expect(ended).toBe(true);

        // responseController must have been aborted, i.e. the first-byte timeout fired.
        const sessionWithController = session as unknown as { responseController?: AbortController };
        expect(sessionWithController.responseController?.signal.aborted).toBe(true);

        // Coarse timing check: an immediate return would mean an unrelated early exit (false positive).
        const elapsed = Date.now() - startedAt;
        expect(elapsed).toBeGreaterThanOrEqual(120);
      } finally {
        clientAbortController.abort(new Error("test_cleanup"));
        await close();
        await Promise.allSettled(asyncTasks);
      }
    });

    test("收到首块数据后应清除首字节超时:后续 chunk 即使晚于 firstByteTimeout 也不应被误中断", async () => {
      asyncTasks.length = 0;
      const { baseUrl, close } = await startSseServer((_req, res) => {
        res.writeHead(200, SSE_HEADERS);
        res.flushHeaders();
        res.write('data: {"x":1}\n\n');
        setTimeout(() => {
          try {
            res.write('data: {"x":2}\n\n');
            res.end();
          } catch {
            // The test may already have torn the connection down; nothing to do.
          }
        }, 150);
      });
      const clientAbortController = new AbortController();
      try {
        const provider = createProvider({
          url: baseUrl,
          firstByteTimeoutStreamingMs: 100,
          streamingIdleTimeoutMs: 0,
        });
        const session = createSession({
          clientAbortSignal: clientAbortController.signal,
          messageId: 2,
          userId: 1,
        });
        session.setProvider(provider);

        const clientResponse = await forwardAndDispatch(session, provider, baseUrl);

        // Guard against a hang, and release the guard timer once the race settles.
        let guardTimer: ReturnType<typeof setTimeout> | undefined;
        const fullText = await Promise.race([
          clientResponse.text(),
          new Promise<"timeout">((resolve) => {
            guardTimer = setTimeout(() => resolve("timeout"), 1500);
          }),
        ]).finally(() => clearTimeout(guardTimer));
        if (fullText === "timeout") {
          clientAbortController.abort(new Error("test_timeout"));
          throw new Error("读取透传响应超时(可能仍会卡死)");
        }

        // The second chunk is sent at 150 ms. If the first-byte timeout were not cleared,
        // the stream would be cut at about 100 ms and the second chunk would never arrive.
        expect(fullText).toContain('"x":2');
      } finally {
        clientAbortController.abort(new Error("test_cleanup"));
        await close();
        await Promise.allSettled(asyncTasks);
      }
    });

    test("中途静默超过 streamingIdleTimeoutMs 时应中断,避免 200 跑到一半卡死", async () => {
      asyncTasks.length = 0;
      const { baseUrl, close } = await startSseServer((_req, res) => {
        res.writeHead(200, SSE_HEADERS);
        res.flushHeaders();
        res.write('data: {"x":1}\n\n');
        // Send nothing further and never end the response.
      });
      const clientAbortController = new AbortController();
      try {
        const provider = createProvider({
          url: baseUrl,
          firstByteTimeoutStreamingMs: 1000,
          streamingIdleTimeoutMs: 120,
        });
        const session = createSession({
          clientAbortSignal: clientAbortController.signal,
          messageId: 3,
          userId: 1,
        });
        session.setProvider(provider);

        const clientResponse = await forwardAndDispatch(session, provider, baseUrl);
        const reader = clientResponse.body?.getReader();
        expect(reader).toBeTruthy();
        if (!reader) throw new Error("Missing body reader");

        const first = await readWithTimeout(reader, 1000);
        expect(first.ok).toBe(true);
        if (!("value" in first)) {
          throw new Error("首块数据读取异常:预期拿到 value,但得到 error");
        }
        expect(first.value.done).toBe(false);

        // Once the idle timeout fires, the next read must settle within the window
        // (done=true or an error are both acceptable).
        const second = await readWithTimeout(reader, 1500);
        if (!second.ok) {
          clientAbortController.abort(new Error("test_timeout"));
          throw new Error("流式静默超时未生效:读后续数据在 1.5s 内仍未返回(可能仍会卡死)");
        }
      } finally {
        clientAbortController.abort(new Error("test_cleanup"));
        await close();
        await Promise.allSettled(asyncTasks);
      }
    });
  });