response-handler-gemini-stream-passthrough-timeouts.test.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. import { createServer } from "node:http";
  2. import type { Socket } from "node:net";
  3. import { beforeEach, describe, expect, test, vi } from "vitest";
  4. import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder";
  5. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  6. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  7. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  8. import type { Provider } from "@/types/provider";
  // Promises handed to the mocked AsyncTaskManager.register. Each test empties this list
  // up front and awaits everything in it (via Promise.allSettled) during cleanup.
  const asyncTasks: Array<Promise<void>> = [];
  // Created through vi.hoisted so the "@/lib/config" mock factory, which vitest hoists
  // above the imports, can reference the same mock function.
  const mocks = vi.hoisted(() => ({
    isHttp2Enabled: vi.fn(async () => false),
  }));

  // Give every test a clean isHttp2Enabled mock that resolves to false.
  beforeEach(() => {
    const { isHttp2Enabled } = mocks;
    isHttp2Enabled.mockReset();
    isHttp2Enabled.mockResolvedValue(false);
  });
  // Keep the real config module, but route isHttp2Enabled through the hoisted mock.
  vi.mock("@/lib/config", async (importOriginal) => ({
    ...(await importOriginal<typeof import("@/lib/config")>()),
    isHttp2Enabled: mocks.isHttp2Enabled,
  }));

  // The fixer returns the response it was given, untouched.
  vi.mock("@/app/v1/_lib/proxy/response-fixer", () => ({
    ResponseFixer: {
      async process(_session: unknown, response: Response) {
        return response;
      },
    },
  }));

  // Record registered background tasks in asyncTasks so each test can await them.
  vi.mock("@/lib/async-task-manager", () => ({
    AsyncTaskManager: {
      register(_taskId: string, promise: Promise<void>) {
        asyncTasks.push(promise);
        return new AbortController();
      },
      cleanup() {},
      cancel() {},
    },
  }));

  // Silence logging.
  vi.mock("@/lib/logger", () => ({
    logger: {
      debug: vi.fn(),
      info: vi.fn(),
      warn: vi.fn(),
      trace: vi.fn(),
      error: vi.fn(),
    },
  }));

  // Persistence stubs: message-request writes are no-op mocks.
  vi.mock("@/repository/message", () => ({
    updateMessageRequestCost: vi.fn(),
    updateMessageRequestDetails: vi.fn(),
    updateMessageRequestDuration: vi.fn(),
  }));

  vi.mock("@/repository/system-config", () => ({
    getSystemSettings: vi.fn(async () => ({ billingModelSource: "original" })),
  }));

  // Every model is priced at zero per input/output token.
  vi.mock("@/repository/model-price", () => ({
    findLatestPriceByModel: vi.fn(async () => ({
      priceData: { input_cost_per_token: 0, output_cost_per_token: 0 },
    })),
  }));

  vi.mock("@/lib/session-manager", () => ({
    SessionManager: {
      storeSessionResponse: vi.fn(),
      updateSessionUsage: vi.fn(),
    },
  }));

  // getInstance() hands back a tracker whose endRequest does nothing.
  vi.mock("@/lib/proxy-status-tracker", () => ({
    ProxyStatusTracker: {
      getInstance() {
        return {
          endRequest() {},
        };
      },
    },
  }));
  /**
   * Builds a complete Gemini `Provider` fixture aimed at a placeholder local URL
   * (`http://127.0.0.1:1`). Limits, preferences and rate fields are null or 0. The streaming
   * first-byte timeout is 100 ms; the streaming idle and non-streaming request timeouts are 0.
   *
   * @param overrides - Fields that replace the defaults (typically `url` and timeout settings).
   * @returns A new object per call, with freshly created `createdAt` / `updatedAt` dates.
   */
  function createProvider(overrides: Partial<Provider> = {}): Provider {
    const defaults: Provider = {
      // identity & routing
      id: 1,
      name: "p1",
      url: "http://127.0.0.1:1",
      key: "k",
      providerVendorId: null,
      isEnabled: true,
      weight: 1,
      priority: 0,
      groupPriorities: null,
      costMultiplier: 1,
      groupTag: null,
      providerType: "gemini",
      preserveClientIp: false,
      modelRedirects: null,
      allowedModels: null,
      mcpPassthroughType: "none",
      mcpPassthroughUrl: null,
      // spending limits
      limit5hUsd: null,
      limitDailyUsd: null,
      dailyResetMode: "fixed",
      dailyResetTime: "00:00",
      limitWeeklyUsd: null,
      limitMonthlyUsd: null,
      limitTotalUsd: null,
      totalCostResetAt: null,
      limitConcurrentSessions: 0,
      // retries & circuit breaker
      maxRetryAttempts: null,
      circuitBreakerFailureThreshold: 5,
      circuitBreakerOpenDuration: 1_800_000,
      circuitBreakerHalfOpenSuccessThreshold: 2,
      // networking & timeouts
      proxyUrl: null,
      proxyFallbackToDirect: false,
      firstByteTimeoutStreamingMs: 100,
      streamingIdleTimeoutMs: 0,
      requestTimeoutNonStreamingMs: 0,
      // presentation & per-vendor preferences
      websiteUrl: null,
      faviconUrl: null,
      cacheTtlPreference: null,
      context1mPreference: null,
      codexReasoningEffortPreference: null,
      codexReasoningSummaryPreference: null,
      codexTextVerbosityPreference: null,
      codexParallelToolCallsPreference: null,
      anthropicMaxTokensPreference: null,
      anthropicThinkingBudgetPreference: null,
      geminiGoogleSearchPreference: null,
      // rate fields
      tpm: 0,
      rpm: 0,
      rpd: 0,
      cc: 0,
      // bookkeeping
      createdAt: new Date(),
      updatedAt: new Date(),
      deletedAt: null,
    };
    return { ...defaults, ...overrides };
  }
  /**
   * Hand-assembles a streaming Gemini `ProxySession` without running its constructor. The
   * object gets `ProxySession.prototype` (so prototype methods such as `setProvider` are
   * available) and only the instance fields assigned below.
   *
   * @param params.clientAbortSignal - Exposed as the session's `clientAbortSignal`.
   * @param params.messageId - Id stored on `messageContext`.
   * @param params.userId - Id of `messageContext.user`.
   * @returns The assembled session.
   */
  function createSession(params: {
    clientAbortSignal: AbortSignal;
    messageId: number;
    userId: number;
  }): ProxySession {
    const { clientAbortSignal, messageId, userId } = params;
    const requestHeaders = new Headers();
    const requestMessage = {
      model: "gemini-2.0-flash",
      stream: true,
      messages: [{ role: "user", content: "hi" }],
    };

    const fields = {
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("https://example.com/v1/chat/completions"),
      headers: requestHeaders,
      originalHeaders: new Headers(requestHeaders),
      headerLog: JSON.stringify(Object.fromEntries(requestHeaders.entries())),
      request: {
        model: "gemini-2.0-flash",
        log: "(test)",
        message: requestMessage,
      },
      userAgent: null,
      context: null,
      clientAbortSignal,
      userName: "test-user",
      authState: { success: true, user: null, key: null, apiKey: null },
      provider: null,
      messageContext: {
        id: messageId,
        createdAt: new Date(),
        user: { id: userId, name: "u1" },
      },
      sessionId: null,
      requestSequence: 1,
      originalFormat: "gemini",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      endpointPolicy: resolveEndpointPolicy("/v1/chat/completions"),
      isHeaderModified: () => false,
    };

    return Object.assign(Object.create(ProxySession.prototype), fields) as ProxySession;
  }
  /**
   * Starts an HTTP server for `handler` on an ephemeral port bound to 127.0.0.1.
   *
   * Every accepted socket is tracked so that `close()` can destroy connections a handler
   * deliberately keeps open (e.g. an SSE response that never ends) before calling
   * `server.close()`.
   *
   * @param handler - Request handler passed straight to `createServer`.
   * @returns The server's base URL and an async `close()` that tears everything down.
   * @throws Rejects if the server emits "error" while listening, or if the bound address
   *   cannot be read as a host/port pair.
   */
  async function startSseServer(handler: Parameters<typeof createServer>[0]): Promise<{
    baseUrl: string;
    close: () => Promise<void>;
  }> {
    const openSockets = new Set<Socket>();
    const server = createServer(handler);

    server.on("connection", (socket) => {
      openSockets.add(socket);
      socket.on("close", () => openSockets.delete(socket));
    });

    const port = await new Promise<number>((resolve, reject) => {
      server.once("error", reject);
      server.listen(0, "127.0.0.1", () => {
        const address = server.address();
        if (address && typeof address !== "string") {
          resolve(address.port);
        } else {
          reject(new Error("Failed to get server address"));
        }
      });
    });

    async function close(): Promise<void> {
      openSockets.forEach((socket) => {
        try {
          socket.destroy();
        } catch {
          // best-effort teardown; a socket that fails to destroy is ignored
        }
      });
      openSockets.clear();
      await new Promise<void>((resolve) => {
        server.close(() => resolve());
      });
    }

    return { baseUrl: `http://127.0.0.1:${port}`, close };
  }
  /**
   * Races one `reader.read()` against a timer so a stalled stream cannot hang a test.
   *
   * Result shapes:
   * - `{ ok: true, value }` - the read settled with a chunk or with `done: true`.
   * - `{ ok: true, error }` - the read rejected. `ok` is still `true` because the read
   *   *returned* within the window; callers use `ok` only to detect the timeout.
   * - `{ ok: false, reason: "timeout" }` - nothing settled within `timeoutMs`. The pending
   *   read is left as-is; callers are expected to abort/cancel the stream themselves.
   *
   * The timer is always cleared once the race settles. Otherwise every fast read would leave
   * a pending timer of up to `timeoutMs` that keeps the event loop alive.
   *
   * @param reader - Reader to pull a single chunk from.
   * @param timeoutMs - Maximum time to wait for the read, in milliseconds.
   */
  async function readWithTimeout(
    reader: ReadableStreamDefaultReader<Uint8Array>,
    timeoutMs: number
  ): Promise<
    | { ok: true; value: ReadableStreamReadResult<Uint8Array> }
    | { ok: true; error: unknown }
    | { ok: false; reason: "timeout" }
  > {
    let timer: ReturnType<typeof setTimeout> | undefined;
    const timeout = new Promise<{ ok: false; reason: "timeout" }>((resolve) => {
      timer = setTimeout(() => resolve({ ok: false as const, reason: "timeout" }), timeoutMs);
    });
    try {
      return await Promise.race([
        reader
          .read()
          .then((value) => ({ ok: true as const, value }))
          .catch((error: unknown) => ({ ok: true as const, error })),
        timeout,
      ]);
    } finally {
      clearTimeout(timer);
    }
  }
  describe("ProxyResponseHandler - Gemini stream passthrough timeouts", () => {
    /**
     * Runs the forwarder's internal `doForward` against the local test server, then passes the
     * upstream response through `ProxyResponseHandler.dispatch` and returns the client-facing
     * result. `doForward` is not on ProxyForwarder's public type, hence the cast.
     */
    async function forwardAndDispatch(session: ProxySession, provider: Provider, baseUrl: string) {
      const doForward = (
        ProxyForwarder as unknown as {
          doForward: (this: typeof ProxyForwarder, ...args: unknown[]) => unknown;
        }
      ).doForward;
      const upstreamResponse = (await doForward.call(
        ProxyForwarder,
        session,
        provider,
        baseUrl
      )) as Response;
      return ProxyResponseHandler.dispatch(session, upstreamResponse);
    }

    test("不应在仅收到 headers 时清除首字节超时:无首块数据时应在窗口内中断避免悬挂", async () => {
      asyncTasks.length = 0;
      const { baseUrl, close } = await startSseServer((_req, res) => {
        res.writeHead(200, {
          "content-type": "text/event-stream",
          "cache-control": "no-cache",
          connection: "keep-alive",
        });
        res.flushHeaders();
        // Send no body at all and keep the connection open.
      });
      const clientAbortController = new AbortController();
      try {
        const provider = createProvider({
          url: baseUrl,
          firstByteTimeoutStreamingMs: 200,
        });
        const session = createSession({
          clientAbortSignal: clientAbortController.signal,
          messageId: 1,
          userId: 1,
        });
        session.setProvider(provider);
        // Start the clock before forwarding. If the first-byte timer is armed during
        // doForward/dispatch, measuring only after dispatch would exclude that time. The
        // lower-bound check below could then fail on a slow machine even when the timeout
        // behaved correctly.
        const startedAt = Date.now();
        const clientResponse = await forwardAndDispatch(session, provider, baseUrl);
        const reader = clientResponse.body?.getReader();
        expect(reader).toBeTruthy();
        if (!reader) throw new Error("Missing body reader");
        const firstRead = await readWithTimeout(reader, 1500);
        if (!firstRead.ok) {
          clientAbortController.abort(new Error("test_timeout"));
          throw new Error("首字节超时未生效:读首块数据在 1.5s 内仍未返回(可能仍会卡死)");
        }
        // The read must have been ended by the timeout/abort: either done=true or a thrown error.
        const ended = ("value" in firstRead && firstRead.value.done === true) || "error" in firstRead;
        expect(ended).toBe(true);
        // The session's responseController must have been aborted, i.e. the first-byte timeout fired.
        const sessionWithController = session as unknown as { responseController?: AbortController };
        expect(sessionWithController.responseController?.signal.aborted).toBe(true);
        // Rough timing check: the read must not end immediately. An instant end would point to
        // an unrelated early exit rather than the timeout (a false positive).
        const elapsed = Date.now() - startedAt;
        expect(elapsed).toBeGreaterThanOrEqual(120);
      } finally {
        clientAbortController.abort(new Error("test_cleanup"));
        await close();
        await Promise.allSettled(asyncTasks);
      }
    });

    test("收到首块数据后应清除首字节超时:后续 chunk 即使晚于 firstByteTimeout 也不应被误中断", async () => {
      asyncTasks.length = 0;
      const { baseUrl, close } = await startSseServer((_req, res) => {
        res.writeHead(200, {
          "content-type": "text/event-stream",
          "cache-control": "no-cache",
          connection: "keep-alive",
        });
        res.flushHeaders();
        res.write('data: {"x":1}\n\n');
        setTimeout(() => {
          try {
            res.write('data: {"x":2}\n\n');
            res.end();
          } catch {
            // the client may already have gone away; nothing to do
          }
        }, 150);
      });
      const clientAbortController = new AbortController();
      try {
        const provider = createProvider({
          url: baseUrl,
          firstByteTimeoutStreamingMs: 100,
          streamingIdleTimeoutMs: 0,
        });
        const session = createSession({
          clientAbortSignal: clientAbortController.signal,
          messageId: 2,
          userId: 1,
        });
        session.setProvider(provider);
        const clientResponse = await forwardAndDispatch(session, provider, baseUrl);
        // Clear the guard timer once the race settles so it does not linger after a fast read.
        let readTimer: ReturnType<typeof setTimeout> | undefined;
        const fullText = await Promise.race([
          clientResponse.text(),
          new Promise<"timeout">((resolve) => {
            readTimer = setTimeout(() => resolve("timeout"), 1500);
          }),
        ]).finally(() => clearTimeout(readTimer));
        if (fullText === "timeout") {
          clientAbortController.abort(new Error("test_timeout"));
          throw new Error("读取透传响应超时(可能仍会卡死)");
        }
        // The second chunk is sent at 150ms. If the first-byte timeout were not cleared by the
        // first chunk, the stream would be cut at ~100ms and the second chunk would never arrive.
        expect(fullText).toContain('"x":2');
      } finally {
        clientAbortController.abort(new Error("test_cleanup"));
        await close();
        await Promise.allSettled(asyncTasks);
      }
    });

    test("中途静默超过 streamingIdleTimeoutMs 时应中断,避免 200 跑到一半卡死", async () => {
      asyncTasks.length = 0;
      const { baseUrl, close } = await startSseServer((_req, res) => {
        res.writeHead(200, {
          "content-type": "text/event-stream",
          "cache-control": "no-cache",
          connection: "keep-alive",
        });
        res.flushHeaders();
        res.write('data: {"x":1}\n\n');
        // Send nothing more and never end the connection.
      });
      const clientAbortController = new AbortController();
      try {
        const provider = createProvider({
          url: baseUrl,
          firstByteTimeoutStreamingMs: 1000,
          streamingIdleTimeoutMs: 120,
        });
        const session = createSession({
          clientAbortSignal: clientAbortController.signal,
          messageId: 3,
          userId: 1,
        });
        session.setProvider(provider);
        const clientResponse = await forwardAndDispatch(session, provider, baseUrl);
        const reader = clientResponse.body?.getReader();
        expect(reader).toBeTruthy();
        if (!reader) throw new Error("Missing body reader");
        const first = await readWithTimeout(reader, 1000);
        expect(first.ok).toBe(true);
        if (!("value" in first)) {
          throw new Error("首块数据读取异常:预期拿到 value,但得到 error");
        }
        expect(first.value.done).toBe(false);
        // Once the idle timeout fires, the next read must return within a reasonable window.
        const second = await readWithTimeout(reader, 1500);
        if (!second.ok) {
          clientAbortController.abort(new Error("test_timeout"));
          throw new Error("流式静默超时未生效:读后续数据在 1.5s 内仍未返回(可能仍会卡死)");
        }
      } finally {
        clientAbortController.abort(new Error("test_cleanup"));
        await close();
        await Promise.allSettled(asyncTasks);
      }
    });
  });