// proxy-forwarder-hedge-first-byte.test.ts
  1. import { beforeEach, describe, expect, test, vi } from "vitest";
  2. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  /**
   * Mock functions shared by the `vi.mock` factories in this file.
   *
   * They are built inside `vi.hoisted` and exposed as one flat object so the
   * factories and the tests can both refer to them as `mocks.<name>`.
   */
  const mocks = vi.hoisted(() => {
    // Stand-ins for "@/app/v1/_lib/proxy/provider-selector" and ".../errors".
    const proxyLayer = {
      pickRandomProviderWithExclusion: vi.fn(),
      categorizeErrorAsync: vi.fn(async () => 0),
    };

    // Stand-ins for "@/lib/circuit-breaker".
    const providerCircuit = {
      recordSuccess: vi.fn(),
      recordFailure: vi.fn(async () => {}),
      getCircuitState: vi.fn(() => "closed"),
      getProviderHealthInfo: vi.fn(async () => ({
        health: { failureCount: 0 },
        config: { failureThreshold: 3 },
      })),
    };

    // Stand-ins for "@/lib/session-manager" (SessionManager members).
    const sessionManager = {
      updateSessionBindingSmart: vi.fn(async () => ({ updated: true, reason: "test" })),
      updateSessionProvider: vi.fn(async () => {}),
      clearSessionProvider: vi.fn(async () => {}),
      storeSessionSpecialSettings: vi.fn(async () => {}),
    };

    // Stand-ins for "@/lib/config", the endpoint selector and both endpoint-level breakers.
    const endpointLayer = {
      isHttp2Enabled: vi.fn(async () => false),
      getPreferredProviderEndpoints: vi.fn(async () => []),
      getEndpointFilterStats: vi.fn(async () => null),
      recordEndpointSuccess: vi.fn(async () => {}),
      recordEndpointFailure: vi.fn(async () => {}),
      isVendorTypeCircuitOpen: vi.fn(async () => false),
      recordVendorTypeAllEndpointsTimeout: vi.fn(async () => {}),
    };

    return {
      ...proxyLayer,
      ...providerCircuit,
      ...sessionManager,
      ...endpointLayer,
    };
  });
  // Module mocks. Each factory only references `mocks` (created via `vi.hoisted`)
  // or values it builds locally.

  // Logger: every level is a no-op spy.
  vi.mock("@/lib/logger", () => {
    const logger = {
      debug: vi.fn(),
      info: vi.fn(),
      warn: vi.fn(),
      trace: vi.fn(),
      error: vi.fn(),
      fatal: vi.fn(),
    };
    return { logger };
  });

  // Config: keep the real module, override only `isHttp2Enabled`.
  vi.mock("@/lib/config", async (importOriginal) => ({
    ...(await importOriginal<typeof import("@/lib/config")>()),
    isHttp2Enabled: mocks.isHttp2Enabled,
  }));

  // Endpoint selection.
  vi.mock("@/lib/provider-endpoints/endpoint-selector", () => {
    const { getPreferredProviderEndpoints, getEndpointFilterStats } = mocks;
    return { getPreferredProviderEndpoints, getEndpointFilterStats };
  });

  // Endpoint-level circuit breaker.
  vi.mock("@/lib/endpoint-circuit-breaker", () => {
    const { recordEndpointSuccess, recordEndpointFailure } = mocks;
    return { recordEndpointSuccess, recordEndpointFailure };
  });

  // Provider-level circuit breaker.
  vi.mock("@/lib/circuit-breaker", () => {
    const { getCircuitState, getProviderHealthInfo, recordFailure, recordSuccess } = mocks;
    return { getCircuitState, getProviderHealthInfo, recordFailure, recordSuccess };
  });

  // Vendor-type circuit breaker.
  vi.mock("@/lib/vendor-type-circuit-breaker", () => {
    const { isVendorTypeCircuitOpen, recordVendorTypeAllEndpointsTimeout } = mocks;
    return { isVendorTypeCircuitOpen, recordVendorTypeAllEndpointsTimeout };
  });

  // Session manager: only the four members stubbed in `mocks` are provided.
  vi.mock("@/lib/session-manager", () => {
    const SessionManager = {
      updateSessionBindingSmart: mocks.updateSessionBindingSmart,
      updateSessionProvider: mocks.updateSessionProvider,
      clearSessionProvider: mocks.clearSessionProvider,
      storeSessionSpecialSettings: mocks.storeSessionSpecialSettings,
    };
    return { SessionManager };
  });

  // Provider selector: only `pickRandomProviderWithExclusion` is provided.
  vi.mock("@/app/v1/_lib/proxy/provider-selector", () => {
    const ProxyProviderResolver = {
      pickRandomProviderWithExclusion: mocks.pickRandomProviderWithExclusion,
    };
    return { ProxyProviderResolver };
  });

  // Proxy errors: keep the real module, override only `categorizeErrorAsync`.
  vi.mock("@/app/v1/_lib/proxy/errors", async (importOriginal) => ({
    ...(await importOriginal<typeof import("@/app/v1/_lib/proxy/errors")>()),
    categorizeErrorAsync: mocks.categorizeErrorAsync,
  }));
  80. import {
  81. ErrorCategory as ProxyErrorCategory,
  82. ProxyError as UpstreamProxyError,
  83. } from "@/app/v1/_lib/proxy/errors";
  84. import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder";
  85. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  86. import type { Provider } from "@/types/provider";
  /**
   * Extra per-attempt fields that the `doForward` stubs in this file attach to
   * the session object they receive.
   */
  interface AttemptRuntime {
    /** Abort controller the stub registers for the attempt's response. */
    responseController?: AbortController;
    /** Callback slot the stubs fill with a `vi.fn()`. */
    clearResponseTimeout?: () => void;
  }
  /**
   * Builds a fully populated `Provider` fixture.
   *
   * @param overrides Fields that replace the defaults below; applied last.
   * @returns A new object on every call (timestamps are created per call).
   */
  function createProvider(overrides: Partial<Provider> = {}): Provider {
    const now = () => new Date();

    const defaults: Provider = {
      // Identity and routing
      id: 1,
      name: "p1",
      url: "https://provider.example.com",
      key: "k",
      providerVendorId: null,
      isEnabled: true,
      weight: 1,
      priority: 0,
      groupPriorities: null,
      costMultiplier: 1,
      groupTag: null,
      providerType: "claude",
      preserveClientIp: false,
      modelRedirects: null,
      allowedModels: null,
      mcpPassthroughType: "none",
      mcpPassthroughUrl: null,

      // Spend limits
      limit5hUsd: null,
      limitDailyUsd: null,
      dailyResetMode: "fixed",
      dailyResetTime: "00:00",
      limitWeeklyUsd: null,
      limitMonthlyUsd: null,
      limitTotalUsd: null,
      totalCostResetAt: null,
      limitConcurrentSessions: 0,

      // Retry and circuit breaker
      maxRetryAttempts: 1,
      circuitBreakerFailureThreshold: 5,
      circuitBreakerOpenDuration: 1_800_000,
      circuitBreakerHalfOpenSuccessThreshold: 2,

      // Proxy and timeouts
      proxyUrl: null,
      proxyFallbackToDirect: false,
      firstByteTimeoutStreamingMs: 100,
      streamingIdleTimeoutMs: 0,
      requestTimeoutNonStreamingMs: 0,

      // Presentation and per-vendor preferences
      websiteUrl: null,
      faviconUrl: null,
      cacheTtlPreference: null,
      context1mPreference: null,
      codexReasoningEffortPreference: null,
      codexReasoningSummaryPreference: null,
      codexTextVerbosityPreference: null,
      codexParallelToolCallsPreference: null,
      codexServiceTierPreference: null,
      anthropicMaxTokensPreference: null,
      anthropicThinkingBudgetPreference: null,
      anthropicAdaptiveThinking: null,
      geminiGoogleSearchPreference: null,

      // Rate fields
      tpm: 0,
      rpm: 0,
      rpd: 0,
      cc: 0,

      // Timestamps
      createdAt: now(),
      updatedAt: now(),
      deletedAt: null,
    };

    return { ...defaults, ...overrides };
  }
  /**
   * Builds a `ProxySession` fixture without running the real constructor: an
   * object is created from `ProxySession.prototype` and the fields the tests
   * need are copied onto it.
   *
   * @param clientAbortSignal Signal stored on the session as `clientAbortSignal`
   *   (defaults to `null`).
   * @returns The populated object, typed as `ProxySession`.
   */
  function createSession(clientAbortSignal: AbortSignal | null = null): ProxySession {
    const requestHeaders = new Headers();
    const modelName = "claude-test";

    const fields = {
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("https://example.com/v1/messages"),
      headers: requestHeaders,
      originalHeaders: new Headers(requestHeaders),
      headerLog: JSON.stringify(Object.fromEntries(requestHeaders.entries())),
      request: {
        model: modelName,
        log: "(test)",
        message: {
          model: modelName,
          stream: true,
          messages: [{ role: "user", content: "hi" }],
        },
      },
      userAgent: null,
      context: null,
      clientAbortSignal,
      userName: "test-user",
      authState: { success: true, user: null, key: null, apiKey: null },
      provider: null,
      messageContext: null,
      sessionId: "sess-hedge",
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      endpointPolicy: resolveEndpointPolicy("/v1/messages"),
      isHeaderModified: () => false,
    };

    // Prototype-only instance: methods come from ProxySession, state from `fields`.
    return Object.assign(Object.create(ProxySession.prototype), fields) as ProxySession;
  }
  /**
   * Creates a 200 `text/event-stream` response whose body emits a single SSE
   * chunk (`data: {"provider":"<label>"}`) after `firstChunkDelayMs`, then closes.
   *
   * If `params.controller` is aborted before the chunk is due, the body closes
   * without emitting anything.
   *
   * The stream is finalized exactly once. Whichever of timer, abort, or consumer
   * cancel happens first tears down the other two, so an abort arriving after
   * the stream has already finished (or been cancelled) is a harmless no-op
   * instead of calling `close()` on an already-closed stream and throwing
   * inside the abort listener.
   *
   * @param params.label Provider label embedded in the emitted chunk.
   * @param params.firstChunkDelayMs Delay before the chunk is enqueued.
   * @param params.controller Abort controller that ends the stream early.
   * @returns The streaming `Response`.
   */
  function createStreamingResponse(params: {
    label: string;
    firstChunkDelayMs: number;
    controller: AbortController;
  }): Response {
    const { signal } = params.controller;
    const encoder = new TextEncoder();
    let timeoutId: ReturnType<typeof setTimeout> | null = null;
    let detachAbortListener: () => void = () => {};
    let finished = false;

    // Returns true only for the first caller; later callers must not touch the stream.
    const teardown = (): boolean => {
      if (finished) {
        return false;
      }
      finished = true;
      if (timeoutId !== null) {
        clearTimeout(timeoutId);
        timeoutId = null;
      }
      detachAbortListener();
      return true;
    };

    const stream = new ReadableStream<Uint8Array>({
      start(streamController) {
        const onAbort = () => {
          if (teardown()) {
            streamController.close();
          }
        };

        if (signal.aborted) {
          onAbort();
          return;
        }

        signal.addEventListener("abort", onAbort, { once: true });
        detachAbortListener = () => signal.removeEventListener("abort", onAbort);

        timeoutId = setTimeout(() => {
          if (!teardown()) {
            return;
          }
          if (!signal.aborted) {
            streamController.enqueue(encoder.encode(`data: {"provider":"${params.label}"}\n\n`));
          }
          streamController.close();
        }, params.firstChunkDelayMs);
      },
      cancel() {
        // The consumer cancelled the body: the stream is already closed, so only
        // release the timer and the abort listener.
        teardown();
      },
    });

    return new Response(stream, {
      status: 200,
      headers: { "content-type": "text/event-stream" },
    });
  }
  /**
   * Returns a promise that never resolves and rejects with `params.error`
   * either after `params.delayMs` or as soon as `params.controller` aborts,
   * whichever comes first. An already-aborted controller rejects immediately.
   *
   * @param params.delayMs Delay before the timed rejection.
   * @param params.error Error used as the rejection reason in every path.
   * @param params.controller Abort controller that triggers early rejection.
   */
  function createDelayedFailure(params: {
    delayMs: number;
    error: Error;
    controller: AbortController;
  }): Promise<Response> {
    const { signal } = params.controller;

    return new Promise<Response>((_resolve, reject) => {
      if (signal.aborted) {
        reject(params.error);
        return;
      }

      const timer = setTimeout(() => {
        // Timed path: the abort listener is no longer needed.
        signal.removeEventListener("abort", failOnAbort);
        reject(params.error);
      }, params.delayMs);

      function failOnAbort(): void {
        // Abort path: cancel the pending timed rejection.
        clearTimeout(timer);
        reject(params.error);
      }

      signal.addEventListener("abort", failOnAbort, { once: true });
    });
  }
  /**
   * First-byte hedge scheduling tests for `ProxyForwarder.send`.
   *
   * Each scenario stubs `ProxyForwarder.doForward` with a queue of attempts
   * (streaming responses or delayed failures) and drives time with fake timers.
   */
  describe("ProxyForwarder - first-byte hedge scheduling", () => {
    /** Shape used to reach the `doForward` member that the tests stub. */
    type DoForwardHost = { doForward: (...args: unknown[]) => Promise<Response> };
    type DoForwardSpy = ReturnType<typeof spyOnDoForward>;

    /** Message asserted by the terminal-fallback scenarios. */
    const ALL_PROVIDERS_UNAVAILABLE_MESSAGE = "所有供应商暂时不可用,请稍后重试";

    /** Spies on `ProxyForwarder.doForward` so attempts can be queued on it. */
    function spyOnDoForward() {
      return vi.spyOn(ProxyForwarder as unknown as DoForwardHost, "doForward");
    }

    /**
     * Queues one `doForward` call: registers `controller` and a stub
     * `clearResponseTimeout` on the attempt session, then returns `respond()`.
     */
    function queueAttempt(
      doForward: DoForwardSpy,
      controller: AbortController,
      respond: () => Response | Promise<Response>
    ): void {
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        runtime.responseController = controller;
        runtime.clearResponseTimeout = vi.fn();
        return respond();
      });
    }

    /** Queues a streaming attempt and returns the controller bound to it. */
    function queueStreamingAttempt(
      doForward: DoForwardSpy,
      label: string,
      firstChunkDelayMs: number
    ): AbortController {
      const controller = new AbortController();
      queueAttempt(doForward, controller, () =>
        createStreamingResponse({ label, firstChunkDelayMs, controller })
      );
      return controller;
    }

    /** Queues an attempt that rejects after `delayMs` and returns its controller. */
    function queueFailingAttempt(
      doForward: DoForwardSpy,
      delayMs: number,
      makeError: () => Error
    ): AbortController {
      const controller = new AbortController();
      queueAttempt(doForward, controller, () =>
        createDelayedFailure({ delayMs, error: makeError(), controller })
      );
      return controller;
    }

    /** Provider fixture `p<id>` with a 100 ms streaming first-byte timeout. */
    function createHedgeProvider(id: number, overrides: Partial<Provider> = {}): Provider {
      return createProvider({
        id,
        name: `p${id}`,
        firstByteTimeoutStreamingMs: 100,
        ...overrides,
      });
    }

    /** Runs `body` under fake timers and always restores real timers afterwards. */
    async function withFakeTimers(body: () => Promise<void>): Promise<void> {
      vi.useFakeTimers();
      try {
        await body();
      } finally {
        vi.useRealTimers();
      }
    }

    /** Resolves with the rejection reason of `promise` (or its value if it resolves). */
    function captureRejection(promise: Promise<unknown>): Promise<UpstreamProxyError> {
      return promise.catch((rejection) => rejection) as Promise<UpstreamProxyError>;
    }

    beforeEach(() => {
      vi.clearAllMocks();
      // clearAllMocks only wipes call history. Queued mock*Once values survive it,
      // so a value a test queued but never consumed would leak into the next test.
      // Reset the mocks that tests queue values on and restore their defaults.
      mocks.pickRandomProviderWithExclusion.mockReset();
      mocks.getPreferredProviderEndpoints.mockReset();
      mocks.getPreferredProviderEndpoints.mockImplementation(async () => []);
      mocks.categorizeErrorAsync.mockReset();
      mocks.categorizeErrorAsync.mockImplementation(async () => 0);
    });

    test("first provider exceeds first-byte threshold, second provider starts and wins by first chunk", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1);
        const provider2 = createHedgeProvider(2);
        const session = createSession();
        session.setProvider(provider1);
        mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);

        const doForward = spyOnDoForward();
        // p1's first chunk is due at 220 ms; p2's is due 40 ms after it is launched.
        const controller1 = queueStreamingAttempt(doForward, "p1", 220);
        const controller2 = queueStreamingAttempt(doForward, "p2", 40);

        const responsePromise = ProxyForwarder.send(session);
        await vi.advanceTimersByTimeAsync(100);
        expect(doForward).toHaveBeenCalledTimes(2);
        await vi.advanceTimersByTimeAsync(50);

        const response = await responsePromise;
        expect(await response.text()).toContain('"provider":"p2"');
        expect(controller1.signal.aborted).toBe(true);
        expect(controller2.signal.aborted).toBe(false);
        expect(mocks.recordFailure).not.toHaveBeenCalled();
        expect(mocks.recordSuccess).not.toHaveBeenCalled();
        expect(session.provider?.id).toBe(2);
        expect(mocks.updateSessionBindingSmart).toHaveBeenCalledWith(
          "sess-hedge",
          2,
          0,
          false,
          true
        );
      }));

    test("characterization: hedge still launches alternative provider when maxRetryAttempts > 1", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1, { maxRetryAttempts: 3 });
        const provider2 = createHedgeProvider(2, { maxRetryAttempts: 3 });
        const session = createSession();
        session.setProvider(provider1);
        mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);

        const doForward = spyOnDoForward();
        const controller1 = queueStreamingAttempt(doForward, "p1", 220);
        queueStreamingAttempt(doForward, "p2", 40);

        const responsePromise = ProxyForwarder.send(session);
        await vi.advanceTimersByTimeAsync(100);
        expect(doForward).toHaveBeenCalledTimes(2);
        expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalledTimes(1);

        // Both hedge markers must be recorded before any attempt has won.
        const chainBeforeWinner = session.getProviderChain();
        expect(chainBeforeWinner).toEqual(
          expect.arrayContaining([
            expect.objectContaining({ reason: "hedge_triggered", id: 1 }),
            expect.objectContaining({ reason: "hedge_launched", id: 2 }),
          ])
        );

        await vi.advanceTimersByTimeAsync(50);
        const response = await responsePromise;
        expect(await response.text()).toContain('"provider":"p2"');
        expect(controller1.signal.aborted).toBe(true);
        expect(session.provider?.id).toBe(2);
      }));

    test("first provider can still win after hedge started if it emits first chunk earlier than fallback", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1);
        const provider2 = createHedgeProvider(2);
        const session = createSession();
        session.setProvider(provider1);
        mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);

        const doForward = spyOnDoForward();
        // p1's chunk is due at 140 ms; p2's chunk is due 120 ms after its launch.
        const controller1 = queueStreamingAttempt(doForward, "p1", 140);
        const controller2 = queueStreamingAttempt(doForward, "p2", 120);

        const responsePromise = ProxyForwarder.send(session);
        await vi.advanceTimersByTimeAsync(100);
        expect(doForward).toHaveBeenCalledTimes(2);
        await vi.advanceTimersByTimeAsync(45);

        const response = await responsePromise;
        expect(await response.text()).toContain('"provider":"p1"');
        expect(controller1.signal.aborted).toBe(false);
        expect(controller2.signal.aborted).toBe(true);
        expect(mocks.recordFailure).not.toHaveBeenCalled();
        expect(mocks.recordSuccess).not.toHaveBeenCalled();
        expect(session.provider?.id).toBe(1);
      }));

    test("when multiple providers all exceed threshold, hedge scheduler keeps expanding until a later provider wins", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1);
        const provider2 = createHedgeProvider(2);
        const provider3 = createHedgeProvider(3);
        const session = createSession();
        session.setProvider(provider1);
        mocks.pickRandomProviderWithExclusion
          .mockResolvedValueOnce(provider2)
          .mockResolvedValueOnce(provider3);

        const doForward = spyOnDoForward();
        const controller1 = queueStreamingAttempt(doForward, "p1", 400);
        const controller2 = queueStreamingAttempt(doForward, "p2", 400);
        const controller3 = queueStreamingAttempt(doForward, "p3", 20);

        const responsePromise = ProxyForwarder.send(session);
        await vi.advanceTimersByTimeAsync(200);
        expect(doForward).toHaveBeenCalledTimes(3);
        await vi.advanceTimersByTimeAsync(25);

        const response = await responsePromise;
        expect(await response.text()).toContain('"provider":"p3"');
        expect(controller1.signal.aborted).toBe(true);
        expect(controller2.signal.aborted).toBe(true);
        expect(controller3.signal.aborted).toBe(false);
        expect(mocks.recordFailure).not.toHaveBeenCalled();
        expect(mocks.recordSuccess).not.toHaveBeenCalled();
        expect(session.provider?.id).toBe(3);
      }));

    test("client abort before any winner should abort all in-flight attempts, return 499, and clear sticky provider binding", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1);
        const provider2 = createHedgeProvider(2);
        const clientAbortController = new AbortController();
        const session = createSession(clientAbortController.signal);
        session.setProvider(provider1);
        mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);

        const doForward = spyOnDoForward();
        const controller1 = queueStreamingAttempt(doForward, "p1", 500);
        const controller2 = queueStreamingAttempt(doForward, "p2", 500);

        const responsePromise = ProxyForwarder.send(session);
        // Attach the rejection expectation before time advances so the rejection is observed.
        const rejection = expect(responsePromise).rejects.toMatchObject({
          statusCode: 499,
        });

        await vi.advanceTimersByTimeAsync(100);
        expect(doForward).toHaveBeenCalledTimes(2);
        clientAbortController.abort(new Error("client_cancelled"));
        await vi.runAllTimersAsync();
        await rejection;

        expect(controller1.signal.aborted).toBe(true);
        expect(controller2.signal.aborted).toBe(true);
        expect(mocks.clearSessionProvider).toHaveBeenCalledWith("sess-hedge");
        expect(mocks.recordFailure).not.toHaveBeenCalled();
        expect(mocks.recordSuccess).not.toHaveBeenCalled();
      }));

    test("hedge launcher rejection should settle request instead of hanging", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1);
        const session = createSession();
        session.setProvider(provider1);
        mocks.pickRandomProviderWithExclusion.mockRejectedValueOnce(new Error("selector down"));

        const doForward = spyOnDoForward();
        const controller1 = queueStreamingAttempt(doForward, "p1", 500);

        const responsePromise = ProxyForwarder.send(session);
        const rejection = expect(responsePromise).rejects.toMatchObject({
          statusCode: 503,
        });

        await vi.advanceTimersByTimeAsync(100);
        await vi.runAllTimersAsync();
        await rejection;
        expect(controller1.signal.aborted).toBe(true);
      }));

    test("strict endpoint pool exhaustion should converge to terminal fallback instead of provider-specific error", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1, {
          providerType: "claude",
          providerVendorId: 123,
        });
        const session = createSession();
        session.requestUrl = new URL("https://example.com/v1/messages");
        session.setProvider(provider1);
        mocks.getPreferredProviderEndpoints.mockRejectedValueOnce(
          new Error("Redis connection lost")
        );
        mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(null);

        const errorPromise = captureRejection(ProxyForwarder.send(session));
        await vi.runAllTimersAsync();
        const error = await errorPromise;

        expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalled();
        expect(error).toBeInstanceOf(UpstreamProxyError);
        expect(error.statusCode).toBe(503);
        expect(error.message).toBe(ALL_PROVIDERS_UNAVAILABLE_MESSAGE);
      }));

    test.each([
      {
        name: "provider error",
        category: ProxyErrorCategory.PROVIDER_ERROR,
        errorFactory: (provider: Provider) =>
          new UpstreamProxyError("Provider returned 401: invalid key", 401, {
            body: '{"error":"invalid_api_key"}',
            providerId: provider.id,
            providerName: provider.name,
          }),
      },
      {
        name: "resource not found",
        category: ProxyErrorCategory.RESOURCE_NOT_FOUND,
        errorFactory: (provider: Provider) =>
          new UpstreamProxyError("Provider returned 404: model not found", 404, {
            body: '{"error":"model_not_found"}',
            providerId: provider.id,
            providerName: provider.name,
          }),
      },
      {
        name: "system error",
        category: ProxyErrorCategory.SYSTEM_ERROR,
        errorFactory: () => new Error("fetch failed"),
      },
    ])(
      "when a real hedge race ends with only $name, terminal error should be generic fallback",
      ({ category, errorFactory }) =>
        withFakeTimers(async () => {
          const provider1 = createHedgeProvider(1);
          const provider2 = createHedgeProvider(2);
          const session = createSession();
          session.setProvider(provider1);
          mocks.pickRandomProviderWithExclusion
            .mockResolvedValueOnce(provider2)
            .mockResolvedValueOnce(null);
          mocks.categorizeErrorAsync
            .mockResolvedValueOnce(category)
            .mockResolvedValueOnce(category);

          const doForward = spyOnDoForward();
          // Both attempts outlive the 100 ms hedge threshold before failing.
          queueFailingAttempt(doForward, 150, () => errorFactory(provider1));
          queueFailingAttempt(doForward, 160, () => errorFactory(provider2));

          const errorPromise = captureRejection(ProxyForwarder.send(session));
          await vi.advanceTimersByTimeAsync(100);
          expect(doForward).toHaveBeenCalledTimes(2);
          await vi.runAllTimersAsync();

          const error = await errorPromise;
          expect(error).toBeInstanceOf(UpstreamProxyError);
          expect(error.statusCode).toBe(503);
          expect(error.message).toBe(ALL_PROVIDERS_UNAVAILABLE_MESSAGE);
          expect(error.message).not.toContain("invalid key");
          expect(error.message).not.toContain("model not found");
          expect(mocks.clearSessionProvider).toHaveBeenCalledWith("sess-hedge");
        })
    );

    test("non-retryable client errors should stop hedge immediately and preserve original error", async () => {
      const provider1 = createHedgeProvider(1);
      const provider2 = createHedgeProvider(2);
      const session = createSession();
      session.setProvider(provider1);

      const originalError = new UpstreamProxyError("prompt too long", 400, {
        body: '{"error":"prompt_too_long"}',
        providerId: provider1.id,
        providerName: provider1.name,
      });

      // An alternative is on offer, but the forwarder is expected never to ask for it.
      mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
      mocks.categorizeErrorAsync.mockResolvedValueOnce(
        ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR
      );

      const doForward = spyOnDoForward();
      doForward.mockRejectedValueOnce(originalError);

      const error = await captureRejection(ProxyForwarder.send(session));

      expect(error).toBe(originalError);
      expect(error.message).toBe("prompt too long");
      expect(doForward).toHaveBeenCalledTimes(1);
      expect(mocks.pickRandomProviderWithExclusion).not.toHaveBeenCalled();
      expect(mocks.clearSessionProvider).toHaveBeenCalledWith("sess-hedge");
      expect(session.getProviderChain()).toEqual(
        expect.arrayContaining([
          expect.objectContaining({
            reason: "client_error_non_retryable",
            statusCode: 400,
          }),
        ])
      );
    });

    test("endpoint resolution failure should not inflate launchedProviderCount, winner gets request_success not hedge_winner", () =>
      withFakeTimers(async () => {
        const provider1 = createHedgeProvider(1, { providerVendorId: 123 });
        const provider2 = createHedgeProvider(2, { providerVendorId: null });
        const session = createSession();
        session.requestUrl = new URL("https://example.com/v1/messages");
        session.setProvider(provider1);

        // Provider 1's strict endpoint resolution will fail
        mocks.getPreferredProviderEndpoints.mockRejectedValueOnce(
          new Error("Endpoint resolution failed")
        );
        // After provider 1 fails, pick provider 2 as alternative
        mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);

        const doForward = spyOnDoForward();
        // Only provider 2 reaches doForward (provider 1 fails at endpoint resolution)
        queueStreamingAttempt(doForward, "p2", 10);

        const responsePromise = ProxyForwarder.send(session);
        await vi.advanceTimersByTimeAsync(200);
        const response = await responsePromise;

        expect(await response.text()).toContain('"provider":"p2"');
        expect(session.provider?.id).toBe(2);

        // Key assertion: since only provider 2 actually launched (provider 1 failed at
        // endpoint resolution before incrementing launchedProviderCount), the winner
        // should be classified as "request_success" not "hedge_winner".
        const chain = session.getProviderChain();
        const winnerEntry = chain.find(
          (entry) => entry.reason === "request_success" || entry.reason === "hedge_winner"
        );
        expect(winnerEntry).toBeDefined();
        expect(winnerEntry?.reason).toBe("request_success");
      }));
  });