proxy-forwarder-hedge-first-byte.test.ts 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154
  1. import { beforeEach, describe, expect, test, vi } from "vitest";
  2. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  /**
   * Spies shared by every `vi.mock` factory below. They are created inside
   * `vi.hoisted` so they exist before the hoisted mock factories run.
   * Each group mirrors one mocked module; the groups are merged into one flat object.
   */
  const mocks = vi.hoisted(() => {
    // ProxyProviderResolver (provider-selector)
    const providerSelection = {
      pickRandomProviderWithExclusion: vi.fn(),
    };
    // @/lib/circuit-breaker: circuit reported "closed", zero failures out of a threshold of 3
    const providerCircuit = {
      recordSuccess: vi.fn(),
      recordFailure: vi.fn(async () => {}),
      getCircuitState: vi.fn(() => "closed"),
      getProviderHealthInfo: vi.fn(async () => ({
        health: { failureCount: 0 },
        config: { failureThreshold: 3 },
      })),
    };
    // SessionManager
    const sessionStore = {
      updateSessionBindingSmart: vi.fn(async () => ({ updated: true, reason: "test" })),
      updateSessionProvider: vi.fn(async () => {}),
      clearSessionProvider: vi.fn(async () => {}),
      storeSessionSpecialSettings: vi.fn(async () => {}),
    };
    // @/lib/config
    const config = {
      isHttp2Enabled: vi.fn(async () => false),
      getCachedSystemSettings: vi.fn(async () => ({
        enableThinkingSignatureRectifier: true,
        enableThinkingBudgetRectifier: true,
      })),
    };
    // Endpoint selector + endpoint circuit breaker
    const endpoints = {
      getPreferredProviderEndpoints: vi.fn(async () => []),
      getEndpointFilterStats: vi.fn(async () => null),
      recordEndpointSuccess: vi.fn(async () => {}),
      recordEndpointFailure: vi.fn(async () => {}),
    };
    // Vendor-type circuit breaker
    const vendorCircuit = {
      isVendorTypeCircuitOpen: vi.fn(async () => false),
      recordVendorTypeAllEndpointsTimeout: vi.fn(async () => {}),
    };
    // Proxy error categorisation (resolves to category 0 unless a test overrides it)
    const errorCategorisation = {
      categorizeErrorAsync: vi.fn(async () => 0),
    };
    return {
      ...providerSelection,
      ...providerCircuit,
      ...sessionStore,
      ...config,
      ...endpoints,
      ...vendorCircuit,
      ...errorCategorisation,
    };
  });
  // Module doubles. Vitest hoists every `vi.mock` call above the imports, so these
  // factories may reference only `vi`, their `importOriginal` argument and the hoisted `mocks`.

  // Logger: every level is an inert spy.
  vi.mock("@/lib/logger", () => {
    const levels = ["debug", "info", "warn", "trace", "error", "fatal"] as const;
    const logger = Object.fromEntries(levels.map((level) => [level, vi.fn()]));
    return { logger };
  });

  // Config: keep the real module, but stub system settings and the HTTP/2 switch.
  vi.mock("@/lib/config", async (importOriginal) => {
    const actual = await importOriginal<typeof import("@/lib/config")>();
    const overrides = {
      getCachedSystemSettings: mocks.getCachedSystemSettings,
      isHttp2Enabled: mocks.isHttp2Enabled,
    };
    return { ...actual, ...overrides };
  });

  vi.mock("@/lib/provider-endpoints/endpoint-selector", () => {
    const { getPreferredProviderEndpoints, getEndpointFilterStats } = mocks;
    return { getPreferredProviderEndpoints, getEndpointFilterStats };
  });

  vi.mock("@/lib/endpoint-circuit-breaker", () => {
    const { recordEndpointSuccess, recordEndpointFailure } = mocks;
    return { recordEndpointSuccess, recordEndpointFailure };
  });

  vi.mock("@/lib/circuit-breaker", () => {
    const { getCircuitState, getProviderHealthInfo, recordFailure, recordSuccess } = mocks;
    return { getCircuitState, getProviderHealthInfo, recordFailure, recordSuccess };
  });

  vi.mock("@/lib/vendor-type-circuit-breaker", () => {
    const { isVendorTypeCircuitOpen, recordVendorTypeAllEndpointsTimeout } = mocks;
    return { isVendorTypeCircuitOpen, recordVendorTypeAllEndpointsTimeout };
  });

  vi.mock("@/lib/session-manager", () => {
    const {
      updateSessionBindingSmart,
      updateSessionProvider,
      clearSessionProvider,
      storeSessionSpecialSettings,
    } = mocks;
    return {
      SessionManager: {
        updateSessionBindingSmart,
        updateSessionProvider,
        clearSessionProvider,
        storeSessionSpecialSettings,
      },
    };
  });

  vi.mock("@/app/v1/_lib/proxy/provider-selector", () => {
    const { pickRandomProviderWithExclusion } = mocks;
    return { ProxyProviderResolver: { pickRandomProviderWithExclusion } };
  });

  // Errors: keep the real error classes and enums; stub only the categoriser.
  vi.mock("@/app/v1/_lib/proxy/errors", async (importOriginal) => {
    const actual = await importOriginal<typeof import("@/app/v1/_lib/proxy/errors")>();
    const overrides = { categorizeErrorAsync: mocks.categorizeErrorAsync };
    return { ...actual, ...overrides };
  });
  85. import {
  86. ErrorCategory as ProxyErrorCategory,
  87. ProxyError as UpstreamProxyError,
  88. } from "@/app/v1/_lib/proxy/errors";
  89. import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder";
  90. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  91. import type { Provider } from "@/types/provider";
  /**
   * Per-attempt runtime fields that the `doForward` stubs attach to the attempt
   * session, so the forwarder can abort or un-time the attempt.
   */
  interface AttemptRuntime {
    clearResponseTimeout?: () => void;
    responseController?: AbortController;
  }
  /**
   * Builds a fully populated `Provider` fixture.
   *
   * Defaults: id 1 / name "p1", an enabled "claude" provider, a single retry
   * attempt, a 100 ms streaming first-byte timeout, and no cost, rate or
   * preference limits. Fields in `overrides` are applied last and win.
   *
   * @param overrides - fields to replace in the default fixture.
   * @returns a new provider object on every call.
   */
  function createProvider(overrides: Partial<Provider> = {}): Provider {
    const defaults: Provider = {
      // identity & routing
      id: 1,
      name: "p1",
      url: "https://provider.example.com",
      key: "k",
      providerVendorId: null,
      isEnabled: true,
      weight: 1,
      priority: 0,
      groupPriorities: null,
      costMultiplier: 1,
      groupTag: null,
      providerType: "claude",
      preserveClientIp: false,
      modelRedirects: null,
      allowedModels: null,
      mcpPassthroughType: "none",
      mcpPassthroughUrl: null,
      // spending limits
      limit5hUsd: null,
      limitDailyUsd: null,
      dailyResetMode: "fixed",
      dailyResetTime: "00:00",
      limitWeeklyUsd: null,
      limitMonthlyUsd: null,
      limitTotalUsd: null,
      totalCostResetAt: null,
      limitConcurrentSessions: 0,
      // retries & circuit breaker
      maxRetryAttempts: 1,
      circuitBreakerFailureThreshold: 5,
      circuitBreakerOpenDuration: 1_800_000,
      circuitBreakerHalfOpenSuccessThreshold: 2,
      // transport & timeouts
      proxyUrl: null,
      proxyFallbackToDirect: false,
      firstByteTimeoutStreamingMs: 100,
      streamingIdleTimeoutMs: 0,
      requestTimeoutNonStreamingMs: 0,
      // presentation
      websiteUrl: null,
      faviconUrl: null,
      // request-shaping preferences
      cacheTtlPreference: null,
      context1mPreference: null,
      codexReasoningEffortPreference: null,
      codexReasoningSummaryPreference: null,
      codexTextVerbosityPreference: null,
      codexParallelToolCallsPreference: null,
      codexServiceTierPreference: null,
      anthropicMaxTokensPreference: null,
      anthropicThinkingBudgetPreference: null,
      anthropicAdaptiveThinking: null,
      geminiGoogleSearchPreference: null,
      // rate limits
      tpm: 0,
      rpm: 0,
      rpd: 0,
      cc: 0,
      // timestamps: two distinct Date instances, as before
      createdAt: new Date(),
      updatedAt: new Date(),
      deletedAt: null,
    };
    return { ...defaults, ...overrides };
  }
  /**
   * Creates a `ProxySession` for a streaming Claude `POST /v1/messages` request.
   * The real constructor is skipped: the instance comes from `ProxySession.prototype`
   * and its fields are then assigned with `Object.assign`.
   *
   * @param clientAbortSignal - stored as `clientAbortSignal` on the session (default `null`).
   * @returns a session bound to sticky id "sess-hedge" with an empty provider chain.
   */
  function createSession(clientAbortSignal: AbortSignal | null = null): ProxySession {
    const requestHeaders = new Headers();
    const fields = {
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("https://example.com/v1/messages"),
      headers: requestHeaders,
      originalHeaders: new Headers(requestHeaders),
      headerLog: JSON.stringify(Object.fromEntries(requestHeaders.entries())),
      request: {
        model: "claude-test",
        log: "(test)",
        message: {
          model: "claude-test",
          stream: true,
          messages: [{ role: "user", content: "hi" }],
        },
      },
      userAgent: null,
      context: null,
      clientAbortSignal,
      userName: "test-user",
      authState: { success: true, user: null, key: null, apiKey: null },
      provider: null,
      messageContext: null,
      sessionId: "sess-hedge",
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      endpointPolicy: resolveEndpointPolicy("/v1/messages"),
      isHeaderModified: () => false,
    };
    const instance = Object.create(ProxySession.prototype) as ProxySession;
    Object.assign(instance, fields);
    return instance;
  }
  /**
   * Fake SSE upstream response. It emits one `data: {"provider":"<label>"}` event
   * after `firstChunkDelayMs` and then closes. If `controller` aborts first, the
   * stream closes empty. An already-aborted controller yields an empty body immediately.
   *
   * The stream is settled exactly once:
   * - After the chunk is delivered, the abort listener is detached, so a later
   *   `abort()` does not call `close()` on an already-closed stream.
   * - If the consumer cancels the body (for example a losing hedge attempt), the
   *   pending timer is cleared, so it never enqueues into a cancelled stream.
   * Either case would otherwise throw an uncaught TypeError.
   *
   * @returns a 200 `text/event-stream` Response backed by the fake stream.
   */
  function createStreamingResponse(params: {
    label: string;
    firstChunkDelayMs: number;
    controller: AbortController;
  }): Response {
    const encoder = new TextEncoder();
    const { signal } = params.controller;
    let timeoutId: ReturnType<typeof setTimeout> | null = null;
    let settled = false;
    // Replaced in `start`: closes the stream (after an optional chunk) at most once.
    let finish: (chunk?: Uint8Array) => void = () => {};

    const onAbort = () => finish();

    // Drops the pending first-chunk timer and the abort listener. Safe to call repeatedly.
    const cleanup = () => {
      if (timeoutId !== null) {
        clearTimeout(timeoutId);
        timeoutId = null;
      }
      signal.removeEventListener("abort", onAbort);
    };

    const stream = new ReadableStream<Uint8Array>({
      start(controller) {
        finish = (chunk) => {
          if (settled) {
            return;
          }
          settled = true;
          cleanup();
          if (chunk) {
            controller.enqueue(chunk);
          }
          controller.close();
        };
        if (signal.aborted) {
          finish();
          return;
        }
        signal.addEventListener("abort", onAbort, { once: true });
        timeoutId = setTimeout(() => {
          timeoutId = null;
          // Defensive: if the signal is somehow already aborted, close without a chunk.
          finish(
            signal.aborted
              ? undefined
              : encoder.encode(`data: {"provider":"${params.label}"}\n\n`)
          );
        }, params.firstChunkDelayMs);
      },
      cancel() {
        // The consumer abandoned the body; the controller must not be touched again.
        settled = true;
        cleanup();
      },
    });
    return new Response(stream, {
      status: 200,
      headers: { "content-type": "text/event-stream" },
    });
  }
  /**
   * Returns a promise that never resolves. It rejects with `params.error` after
   * `params.delayMs`, or as soon as `params.controller` aborts, whichever is first.
   * An already-aborted controller rejects immediately.
   * Aborting clears the pending timer, and the timer detaches the abort listener when it fires.
   */
  function createDelayedFailure(params: {
    delayMs: number;
    error: Error;
    controller: AbortController;
  }): Promise<Response> {
    return new Promise<Response>((_resolve, reject) => {
      const { signal } = params.controller;
      if (signal.aborted) {
        reject(params.error);
        return;
      }
      let pendingTimer: ReturnType<typeof setTimeout> | null = null;
      const failOnAbort = () => {
        if (pendingTimer) {
          clearTimeout(pendingTimer);
        }
        reject(params.error);
      };
      signal.addEventListener("abort", failOnAbort, { once: true });
      pendingTimer = setTimeout(() => {
        signal.removeEventListener("abort", failOnAbort);
        reject(params.error);
      }, params.delayMs);
    });
  }
  /**
   * Replaces the session's request message with a streaming "claude-test" request
   * containing one assistant turn. That turn holds a thinking block, a text block
   * and a redacted_thinking block, each carrying a `signature` field. The text
   * block's signature is named "sig_text_should_remove".
   * Only `session.request.message` is overwritten; other request fields are untouched.
   */
  function withThinkingBlocks(session: ProxySession): void {
    const { request } = session;
    request.message = {
      model: "claude-test",
      stream: true,
      messages: [
        {
          role: "assistant",
          content: [
            { type: "thinking", thinking: "t", signature: "sig_thinking" },
            { type: "text", text: "hello", signature: "sig_text_should_remove" },
            { type: "redacted_thinking", data: "r", signature: "sig_redacted" },
          ],
        },
      ],
    };
  }
  274. describe("ProxyForwarder - first-byte hedge scheduling", () => {
  275. beforeEach(() => {
  276. vi.clearAllMocks();
  277. });
  278. test("first provider exceeds first-byte threshold, second provider starts and wins by first chunk", async () => {
  279. vi.useFakeTimers();
  280. try {
  281. const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
  282. const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
  283. const session = createSession();
  284. session.setProvider(provider1);
  285. mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
  286. const doForward = vi.spyOn(
  287. ProxyForwarder as unknown as {
  288. doForward: (...args: unknown[]) => Promise<Response>;
  289. },
  290. "doForward"
  291. );
  292. const controller1 = new AbortController();
  293. const controller2 = new AbortController();
  294. doForward.mockImplementationOnce(async (attemptSession) => {
  295. const runtime = attemptSession as ProxySession & AttemptRuntime;
  296. runtime.responseController = controller1;
  297. runtime.clearResponseTimeout = vi.fn();
  298. return createStreamingResponse({
  299. label: "p1",
  300. firstChunkDelayMs: 220,
  301. controller: controller1,
  302. });
  303. });
  304. doForward.mockImplementationOnce(async (attemptSession) => {
  305. const runtime = attemptSession as ProxySession & AttemptRuntime;
  306. runtime.responseController = controller2;
  307. runtime.clearResponseTimeout = vi.fn();
  308. return createStreamingResponse({
  309. label: "p2",
  310. firstChunkDelayMs: 40,
  311. controller: controller2,
  312. });
  313. });
  314. const responsePromise = ProxyForwarder.send(session);
  315. await vi.advanceTimersByTimeAsync(100);
  316. expect(doForward).toHaveBeenCalledTimes(2);
  317. await vi.advanceTimersByTimeAsync(50);
  318. const response = await responsePromise;
  319. expect(await response.text()).toContain('"provider":"p2"');
  320. expect(controller1.signal.aborted).toBe(true);
  321. expect(controller2.signal.aborted).toBe(false);
  322. expect(mocks.recordFailure).not.toHaveBeenCalled();
  323. expect(mocks.recordSuccess).not.toHaveBeenCalled();
  324. expect(session.provider?.id).toBe(2);
  325. expect(mocks.updateSessionBindingSmart).toHaveBeenCalledWith("sess-hedge", 2, 0, false, true);
  326. } finally {
  327. vi.useRealTimers();
  328. }
  329. });
  330. test("characterization: hedge still launches alternative provider when maxRetryAttempts > 1", async () => {
  331. vi.useFakeTimers();
  332. try {
  333. const provider1 = createProvider({
  334. id: 1,
  335. name: "p1",
  336. maxRetryAttempts: 3,
  337. firstByteTimeoutStreamingMs: 100,
  338. });
  339. const provider2 = createProvider({
  340. id: 2,
  341. name: "p2",
  342. maxRetryAttempts: 3,
  343. firstByteTimeoutStreamingMs: 100,
  344. });
  345. const session = createSession();
  346. session.setProvider(provider1);
  347. mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
  348. const doForward = vi.spyOn(
  349. ProxyForwarder as unknown as {
  350. doForward: (...args: unknown[]) => Promise<Response>;
  351. },
  352. "doForward"
  353. );
  354. const controller1 = new AbortController();
  355. const controller2 = new AbortController();
  356. doForward.mockImplementationOnce(async (attemptSession) => {
  357. const runtime = attemptSession as ProxySession & AttemptRuntime;
  358. runtime.responseController = controller1;
  359. runtime.clearResponseTimeout = vi.fn();
  360. return createStreamingResponse({
  361. label: "p1",
  362. firstChunkDelayMs: 220,
  363. controller: controller1,
  364. });
  365. });
  366. doForward.mockImplementationOnce(async (attemptSession) => {
  367. const runtime = attemptSession as ProxySession & AttemptRuntime;
  368. runtime.responseController = controller2;
  369. runtime.clearResponseTimeout = vi.fn();
  370. return createStreamingResponse({
  371. label: "p2",
  372. firstChunkDelayMs: 40,
  373. controller: controller2,
  374. });
  375. });
  376. const responsePromise = ProxyForwarder.send(session);
  377. await vi.advanceTimersByTimeAsync(100);
  378. expect(doForward).toHaveBeenCalledTimes(2);
  379. expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalledTimes(1);
  380. const chainBeforeWinner = session.getProviderChain();
  381. expect(chainBeforeWinner).toEqual(
  382. expect.arrayContaining([
  383. expect.objectContaining({ reason: "hedge_triggered", id: 1 }),
  384. expect.objectContaining({ reason: "hedge_launched", id: 2 }),
  385. ])
  386. );
  387. await vi.advanceTimersByTimeAsync(50);
  388. const response = await responsePromise;
  389. expect(await response.text()).toContain('"provider":"p2"');
  390. expect(controller1.signal.aborted).toBe(true);
  391. expect(session.provider?.id).toBe(2);
  392. } finally {
  393. vi.useRealTimers();
  394. }
  395. });
  396. test("first provider can still win after hedge started if it emits first chunk earlier than fallback", async () => {
  397. vi.useFakeTimers();
  398. try {
  399. const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
  400. const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
  401. const session = createSession();
  402. session.setProvider(provider1);
  403. mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
  404. const doForward = vi.spyOn(
  405. ProxyForwarder as unknown as {
  406. doForward: (...args: unknown[]) => Promise<Response>;
  407. },
  408. "doForward"
  409. );
  410. const controller1 = new AbortController();
  411. const controller2 = new AbortController();
  412. doForward.mockImplementationOnce(async (attemptSession) => {
  413. const runtime = attemptSession as ProxySession & AttemptRuntime;
  414. runtime.responseController = controller1;
  415. runtime.clearResponseTimeout = vi.fn();
  416. return createStreamingResponse({
  417. label: "p1",
  418. firstChunkDelayMs: 140,
  419. controller: controller1,
  420. });
  421. });
  422. doForward.mockImplementationOnce(async (attemptSession) => {
  423. const runtime = attemptSession as ProxySession & AttemptRuntime;
  424. runtime.responseController = controller2;
  425. runtime.clearResponseTimeout = vi.fn();
  426. return createStreamingResponse({
  427. label: "p2",
  428. firstChunkDelayMs: 120,
  429. controller: controller2,
  430. });
  431. });
  432. const responsePromise = ProxyForwarder.send(session);
  433. await vi.advanceTimersByTimeAsync(100);
  434. expect(doForward).toHaveBeenCalledTimes(2);
  435. await vi.advanceTimersByTimeAsync(45);
  436. const response = await responsePromise;
  437. expect(await response.text()).toContain('"provider":"p1"');
  438. expect(controller1.signal.aborted).toBe(false);
  439. expect(controller2.signal.aborted).toBe(true);
  440. expect(mocks.recordFailure).not.toHaveBeenCalled();
  441. expect(mocks.recordSuccess).not.toHaveBeenCalled();
  442. expect(session.provider?.id).toBe(1);
  443. } finally {
  444. vi.useRealTimers();
  445. }
  446. });
  447. test("when multiple providers all exceed threshold, hedge scheduler keeps expanding until a later provider wins", async () => {
  448. vi.useFakeTimers();
  449. try {
  450. const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
  451. const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
  452. const provider3 = createProvider({ id: 3, name: "p3", firstByteTimeoutStreamingMs: 100 });
  453. const session = createSession();
  454. session.setProvider(provider1);
  455. mocks.pickRandomProviderWithExclusion
  456. .mockResolvedValueOnce(provider2)
  457. .mockResolvedValueOnce(provider3);
  458. const doForward = vi.spyOn(
  459. ProxyForwarder as unknown as {
  460. doForward: (...args: unknown[]) => Promise<Response>;
  461. },
  462. "doForward"
  463. );
  464. const controller1 = new AbortController();
  465. const controller2 = new AbortController();
  466. const controller3 = new AbortController();
  467. doForward.mockImplementationOnce(async (attemptSession) => {
  468. const runtime = attemptSession as ProxySession & AttemptRuntime;
  469. runtime.responseController = controller1;
  470. runtime.clearResponseTimeout = vi.fn();
  471. return createStreamingResponse({
  472. label: "p1",
  473. firstChunkDelayMs: 400,
  474. controller: controller1,
  475. });
  476. });
  477. doForward.mockImplementationOnce(async (attemptSession) => {
  478. const runtime = attemptSession as ProxySession & AttemptRuntime;
  479. runtime.responseController = controller2;
  480. runtime.clearResponseTimeout = vi.fn();
  481. return createStreamingResponse({
  482. label: "p2",
  483. firstChunkDelayMs: 400,
  484. controller: controller2,
  485. });
  486. });
  487. doForward.mockImplementationOnce(async (attemptSession) => {
  488. const runtime = attemptSession as ProxySession & AttemptRuntime;
  489. runtime.responseController = controller3;
  490. runtime.clearResponseTimeout = vi.fn();
  491. return createStreamingResponse({
  492. label: "p3",
  493. firstChunkDelayMs: 20,
  494. controller: controller3,
  495. });
  496. });
  497. const responsePromise = ProxyForwarder.send(session);
  498. await vi.advanceTimersByTimeAsync(200);
  499. expect(doForward).toHaveBeenCalledTimes(3);
  500. await vi.advanceTimersByTimeAsync(25);
  501. const response = await responsePromise;
  502. expect(await response.text()).toContain('"provider":"p3"');
  503. expect(controller1.signal.aborted).toBe(true);
  504. expect(controller2.signal.aborted).toBe(true);
  505. expect(controller3.signal.aborted).toBe(false);
  506. expect(mocks.recordFailure).not.toHaveBeenCalled();
  507. expect(mocks.recordSuccess).not.toHaveBeenCalled();
  508. expect(session.provider?.id).toBe(3);
  509. } finally {
  510. vi.useRealTimers();
  511. }
  512. });
  513. test("client abort before any winner should abort all in-flight attempts, return 499, and clear sticky provider binding", async () => {
  514. vi.useFakeTimers();
  515. try {
  516. const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
  517. const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
  518. const clientAbortController = new AbortController();
  519. const session = createSession(clientAbortController.signal);
  520. session.setProvider(provider1);
  521. mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
  522. const doForward = vi.spyOn(
  523. ProxyForwarder as unknown as {
  524. doForward: (...args: unknown[]) => Promise<Response>;
  525. },
  526. "doForward"
  527. );
  528. const controller1 = new AbortController();
  529. const controller2 = new AbortController();
  530. doForward.mockImplementationOnce(async (attemptSession) => {
  531. const runtime = attemptSession as ProxySession & AttemptRuntime;
  532. runtime.responseController = controller1;
  533. runtime.clearResponseTimeout = vi.fn();
  534. return createStreamingResponse({
  535. label: "p1",
  536. firstChunkDelayMs: 500,
  537. controller: controller1,
  538. });
  539. });
  540. doForward.mockImplementationOnce(async (attemptSession) => {
  541. const runtime = attemptSession as ProxySession & AttemptRuntime;
  542. runtime.responseController = controller2;
  543. runtime.clearResponseTimeout = vi.fn();
  544. return createStreamingResponse({
  545. label: "p2",
  546. firstChunkDelayMs: 500,
  547. controller: controller2,
  548. });
  549. });
  550. const responsePromise = ProxyForwarder.send(session);
  551. const rejection = expect(responsePromise).rejects.toMatchObject({
  552. statusCode: 499,
  553. });
  554. await vi.advanceTimersByTimeAsync(100);
  555. expect(doForward).toHaveBeenCalledTimes(2);
  556. clientAbortController.abort(new Error("client_cancelled"));
  557. await vi.runAllTimersAsync();
  558. await rejection;
  559. expect(controller1.signal.aborted).toBe(true);
  560. expect(controller2.signal.aborted).toBe(true);
  561. expect(mocks.clearSessionProvider).toHaveBeenCalledWith("sess-hedge");
  562. expect(mocks.recordFailure).not.toHaveBeenCalled();
  563. expect(mocks.recordSuccess).not.toHaveBeenCalled();
  564. } finally {
  565. vi.useRealTimers();
  566. }
  567. });
  568. test("hedge launcher rejection should settle request instead of hanging", async () => {
  569. vi.useFakeTimers();
  570. try {
  571. const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
  572. const session = createSession();
  573. session.setProvider(provider1);
  574. mocks.pickRandomProviderWithExclusion.mockRejectedValueOnce(new Error("selector down"));
  575. const doForward = vi.spyOn(
  576. ProxyForwarder as unknown as {
  577. doForward: (...args: unknown[]) => Promise<Response>;
  578. },
  579. "doForward"
  580. );
  581. const controller1 = new AbortController();
  582. doForward.mockImplementationOnce(async (attemptSession) => {
  583. const runtime = attemptSession as ProxySession & AttemptRuntime;
  584. runtime.responseController = controller1;
  585. runtime.clearResponseTimeout = vi.fn();
  586. return createStreamingResponse({
  587. label: "p1",
  588. firstChunkDelayMs: 500,
  589. controller: controller1,
  590. });
  591. });
  592. const responsePromise = ProxyForwarder.send(session);
  593. const rejection = expect(responsePromise).rejects.toMatchObject({
  594. statusCode: 503,
  595. });
  596. await vi.advanceTimersByTimeAsync(100);
  597. await vi.runAllTimersAsync();
  598. await rejection;
  599. expect(controller1.signal.aborted).toBe(true);
  600. } finally {
  601. vi.useRealTimers();
  602. }
  603. });
  604. test("strict endpoint pool exhaustion should converge to terminal fallback instead of provider-specific error", async () => {
  605. vi.useFakeTimers();
  606. try {
  607. const provider1 = createProvider({
  608. id: 1,
  609. name: "p1",
  610. providerType: "claude",
  611. providerVendorId: 123,
  612. firstByteTimeoutStreamingMs: 100,
  613. });
  614. const session = createSession();
  615. session.requestUrl = new URL("https://example.com/v1/messages");
  616. session.setProvider(provider1);
  617. mocks.getPreferredProviderEndpoints.mockRejectedValueOnce(new Error("Redis connection lost"));
  618. mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(null);
  619. const responsePromise = ProxyForwarder.send(session);
  620. const errorPromise = responsePromise.catch((rejection) => rejection as UpstreamProxyError);
  621. await vi.runAllTimersAsync();
  622. const error = await errorPromise;
  623. expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalled();
  624. expect(error).toBeInstanceOf(UpstreamProxyError);
  625. expect(error.statusCode).toBe(503);
  626. expect(error.message).toBe("所有供应商暂时不可用,请稍后重试");
  627. } finally {
  628. vi.useRealTimers();
  629. }
  630. });
  631. test.each([
  632. {
  633. name: "provider error",
  634. category: ProxyErrorCategory.PROVIDER_ERROR,
  635. errorFactory: (provider: Provider) =>
  636. new UpstreamProxyError("Provider returned 401: invalid key", 401, {
  637. body: '{"error":"invalid_api_key"}',
  638. providerId: provider.id,
  639. providerName: provider.name,
  640. }),
  641. },
  642. {
  643. name: "resource not found",
  644. category: ProxyErrorCategory.RESOURCE_NOT_FOUND,
  645. errorFactory: (provider: Provider) =>
  646. new UpstreamProxyError("Provider returned 404: model not found", 404, {
  647. body: '{"error":"model_not_found"}',
  648. providerId: provider.id,
  649. providerName: provider.name,
  650. }),
  651. },
  652. {
  653. name: "system error",
  654. category: ProxyErrorCategory.SYSTEM_ERROR,
  655. errorFactory: () => new Error("fetch failed"),
  656. },
  657. ])("when a real hedge race ends with only $name, terminal error should be generic fallback", async ({
  658. category,
  659. errorFactory,
  660. }) => {
  661. vi.useFakeTimers();
  662. try {
  663. const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
  664. const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
  665. const session = createSession();
  666. session.setProvider(provider1);
  667. mocks.pickRandomProviderWithExclusion
  668. .mockResolvedValueOnce(provider2)
  669. .mockResolvedValueOnce(null);
  670. mocks.categorizeErrorAsync.mockResolvedValueOnce(category).mockResolvedValueOnce(category);
  671. const doForward = vi.spyOn(
  672. ProxyForwarder as unknown as {
  673. doForward: (...args: unknown[]) => Promise<Response>;
  674. },
  675. "doForward"
  676. );
  677. const controller1 = new AbortController();
  678. const controller2 = new AbortController();
  679. doForward.mockImplementationOnce(async (attemptSession) => {
  680. const runtime = attemptSession as ProxySession & AttemptRuntime;
  681. runtime.responseController = controller1;
  682. runtime.clearResponseTimeout = vi.fn();
  683. return createDelayedFailure({
  684. delayMs: 150,
  685. error: errorFactory(provider1),
  686. controller: controller1,
  687. });
  688. });
  689. doForward.mockImplementationOnce(async (attemptSession) => {
  690. const runtime = attemptSession as ProxySession & AttemptRuntime;
  691. runtime.responseController = controller2;
  692. runtime.clearResponseTimeout = vi.fn();
  693. return createDelayedFailure({
  694. delayMs: 160,
  695. error: errorFactory(provider2),
  696. controller: controller2,
  697. });
  698. });
  699. const responsePromise = ProxyForwarder.send(session);
  700. const errorPromise = responsePromise.catch((rejection) => rejection as UpstreamProxyError);
  701. await vi.advanceTimersByTimeAsync(100);
  702. expect(doForward).toHaveBeenCalledTimes(2);
  703. await vi.runAllTimersAsync();
  704. const error = await errorPromise;
  705. expect(error).toBeInstanceOf(UpstreamProxyError);
  706. expect(error.statusCode).toBe(503);
  707. expect(error.message).toBe("所有供应商暂时不可用,请稍后重试");
  708. expect(error.message).not.toContain("invalid key");
  709. expect(error.message).not.toContain("model not found");
  710. expect(mocks.clearSessionProvider).toHaveBeenCalledWith("sess-hedge");
  711. } finally {
  712. vi.useRealTimers();
  713. }
  714. });
  test("non-retryable client errors should stop hedge immediately and preserve original error", async () => {
    const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
    const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
    const session = createSession();
    session.setProvider(provider1);

    // A 400 from the primary provider that the categorizer (mocked below) labels non-retryable.
    const promptTooLong = new UpstreamProxyError("prompt too long", 400, {
      body: '{"error":"prompt_too_long"}',
      providerId: provider1.id,
      providerName: provider1.name,
    });

    // An alternative provider is queued, yet the assertions below require it is never requested.
    mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
    mocks.categorizeErrorAsync.mockResolvedValueOnce(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR);

    const doForwardSpy = vi
      .spyOn(
        ProxyForwarder as unknown as {
          doForward: (...args: unknown[]) => Promise<Response>;
        },
        "doForward"
      )
      .mockRejectedValueOnce(promptTooLong);

    // Capture whatever `send` rejects with; if it resolves, `caught` stays undefined and the
    // identity check below fails.
    let caught: unknown;
    try {
      await ProxyForwarder.send(session);
    } catch (rejection) {
      caught = rejection;
    }
    const error = caught as UpstreamProxyError;

    // The exact original error object is surfaced, unwrapped.
    expect(error).toBe(promptTooLong);
    expect(error.message).toBe("prompt too long");

    // Only the single primary attempt ran; no alternative was picked.
    expect(doForwardSpy).toHaveBeenCalledTimes(1);
    expect(mocks.pickRandomProviderWithExclusion).not.toHaveBeenCalled();
    expect(mocks.clearSessionProvider).toHaveBeenCalledWith("sess-hedge");

    // The provider chain records the non-retryable client error.
    expect(session.getProviderChain()).toEqual(
      expect.arrayContaining([
        expect.objectContaining({
          reason: "client_error_non_retryable",
          statusCode: 400,
        }),
      ])
    );
  });
  test("hedge 备选供应商命中 thinking signature 错误时,应整流后在同供应商重试并保留审计", async () => {
    vi.useFakeTimers();
    try {
      const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
      const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
      const session = createSession();
      session.setProvider(provider1);
      withThinkingBlocks(session);
      mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
      mocks.categorizeErrorAsync.mockResolvedValue(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR);
      // The signature error is raised by the hedge alternative (p2), not by the primary.
      const signatureError = new UpstreamProxyError(
        "Invalid `signature` in `thinking` block",
        400,
        {
          body: '{"error":"invalid_signature"}',
          providerId: provider2.id,
          providerName: provider2.name,
        }
      );
      const doForward = vi.spyOn(
        ProxyForwarder as unknown as {
          doForward: (...args: unknown[]) => Promise<Response>;
        },
        "doForward"
      );
      const controller1 = new AbortController();
      const controller2First = new AbortController();
      const controller2Retry = new AbortController();
      // Content blocks seen by the rectified retry, copied at call time. They are asserted
      // after `send` settles instead of inside the mock: an expectation thrown from within
      // doForward would be routed through ProxyForwarder's attempt-failure handling rather
      // than failing the test directly.
      let retryBlocks: Array<Record<string, unknown>> | undefined;
      // Attempt 1: primary p1 is slow (first chunk at 600ms), so the 100ms first-byte
      // timeout launches the hedge.
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        runtime.responseController = controller1;
        runtime.clearResponseTimeout = vi.fn();
        return createStreamingResponse({
          label: "p1",
          firstChunkDelayMs: 600,
          controller: controller1,
        });
      });
      // Attempt 2: hedge p2 fails with the signature error 50ms after launch.
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        runtime.responseController = controller2First;
        runtime.clearResponseTimeout = vi.fn();
        return createDelayedFailure({
          delayMs: 50,
          error: signatureError,
          controller: controller2First,
        });
      });
      // Attempt 3: rectified retry on the same provider (p2).
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        const body = runtime.request.message as {
          messages: Array<{ content: Array<Record<string, unknown>> }>;
        };
        retryBlocks = body.messages[0]?.content.map((block) => ({ ...block }));
        runtime.responseController = controller2Retry;
        runtime.clearResponseTimeout = vi.fn();
        return createStreamingResponse({
          label: "p2-rectified",
          firstChunkDelayMs: 180,
          controller: controller2Retry,
        });
      });
      const responsePromise = ProxyForwarder.send(session);
      await vi.advanceTimersByTimeAsync(100);
      expect(doForward).toHaveBeenCalledTimes(2);
      await vi.advanceTimersByTimeAsync(55);
      expect(doForward).toHaveBeenCalledTimes(3);
      await vi.advanceTimersByTimeAsync(200);
      const response = await responsePromise;
      expect(await response.text()).toContain('"provider":"p2-rectified"');
      // The retry must have been sent with every thinking artifact stripped.
      expect(retryBlocks).toBeDefined();
      const blocks = retryBlocks ?? [];
      expect(blocks.some((block) => block.type === "thinking")).toBe(false);
      expect(blocks.some((block) => block.type === "redacted_thinking")).toBe(false);
      expect(blocks.some((block) => "signature" in block)).toBe(false);
      expect(session.provider?.id).toBe(2);
      // The losing primary attempt is aborted once the hedge wins.
      expect(controller1.signal.aborted).toBe(true);
      expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalled();
      expect(mocks.storeSessionSpecialSettings).toHaveBeenCalledWith(
        "sess-hedge",
        expect.arrayContaining([
          expect.objectContaining({
            type: "thinking_signature_rectifier",
            hit: true,
            providerId: 2,
          }),
        ]),
        1
      );
    } finally {
      vi.useRealTimers();
    }
  });
  test("hedge 路径命中 thinking budget 错误时,应整流后在同供应商重试", async () => {
    vi.useFakeTimers();
    try {
      const provider1 = createProvider({ id: 1, name: "p1", firstByteTimeoutStreamingMs: 100 });
      const provider2 = createProvider({ id: 2, name: "p2", firstByteTimeoutStreamingMs: 100 });
      const session = createSession();
      session.setProvider(provider1);
      // budget_tokens (500) is below the minimum quoted in budgetError (1024).
      session.request.message = {
        model: "claude-test",
        stream: true,
        max_tokens: 1000,
        thinking: { type: "enabled", budget_tokens: 500 },
        messages: [{ role: "user", content: "hi" }],
      };
      mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
      mocks.categorizeErrorAsync.mockResolvedValue(ProxyErrorCategory.NON_RETRYABLE_CLIENT_ERROR);
      const budgetError = new UpstreamProxyError(
        "thinking.enabled.budget_tokens: Input should be greater than or equal to 1024",
        400,
        {
          body: '{"error":"budget_too_low"}',
          providerId: provider1.id,
          providerName: provider1.name,
        }
      );
      const doForward = vi.spyOn(
        ProxyForwarder as unknown as {
          doForward: (...args: unknown[]) => Promise<Response>;
        },
        "doForward"
      );
      const controller1First = new AbortController();
      const controller1Retry = new AbortController();
      const controller2 = new AbortController();
      // Request fields seen by the rectified retry, copied at call time. They are asserted
      // after `send` settles instead of inside the mock: an expectation thrown from within
      // doForward would be routed through ProxyForwarder's attempt-failure handling rather
      // than failing the test directly.
      let retryRequest: { max_tokens: unknown; thinking: unknown } | undefined;
      // Attempt 1: p1 fails with the budget error at 140ms, after the 100ms hedge launched.
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        runtime.responseController = controller1First;
        runtime.clearResponseTimeout = vi.fn();
        return createDelayedFailure({
          delayMs: 140,
          error: budgetError,
          controller: controller1First,
        });
      });
      // Attempt 2: hedge p2, whose first chunk (500ms) arrives too late to win.
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        runtime.responseController = controller2;
        runtime.clearResponseTimeout = vi.fn();
        return createStreamingResponse({
          label: "p2",
          firstChunkDelayMs: 500,
          controller: controller2,
        });
      });
      // Attempt 3: rectified retry on the same provider (p1).
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        const body = runtime.request.message as {
          max_tokens: number;
          thinking: { type: string; budget_tokens: number };
        };
        retryRequest = { max_tokens: body.max_tokens, thinking: { ...body.thinking } };
        runtime.responseController = controller1Retry;
        runtime.clearResponseTimeout = vi.fn();
        return createStreamingResponse({
          label: "p1-budget-rectified",
          firstChunkDelayMs: 40,
          controller: controller1Retry,
        });
      });
      const responsePromise = ProxyForwarder.send(session);
      await vi.advanceTimersByTimeAsync(100);
      expect(doForward).toHaveBeenCalledTimes(2);
      await vi.advanceTimersByTimeAsync(45);
      expect(doForward).toHaveBeenCalledTimes(3);
      await vi.advanceTimersByTimeAsync(50);
      const response = await responsePromise;
      expect(await response.text()).toContain('"provider":"p1-budget-rectified"');
      // The retry must carry the rectified budget and the raised max_tokens.
      expect(retryRequest).toEqual({
        max_tokens: 64000,
        thinking: expect.objectContaining({ type: "enabled", budget_tokens: 32000 }),
      });
      expect(session.provider?.id).toBe(1);
      expect(mocks.pickRandomProviderWithExclusion).toHaveBeenCalledTimes(1);
      expect(session.getSpecialSettings()).toEqual(
        expect.arrayContaining([
          expect.objectContaining({
            type: "thinking_budget_rectifier",
            hit: true,
            providerId: 1,
          }),
        ])
      );
    } finally {
      vi.useRealTimers();
    }
  });
  test("endpoint resolution failure should not inflate launchedProviderCount, winner gets request_success not hedge_winner", async () => {
    vi.useFakeTimers();
    try {
      const provider1 = createProvider({
        id: 1,
        name: "p1",
        providerVendorId: 123,
        firstByteTimeoutStreamingMs: 100,
      });
      const provider2 = createProvider({
        id: 2,
        name: "p2",
        providerVendorId: null,
        firstByteTimeoutStreamingMs: 100,
      });
      const session = createSession();
      session.requestUrl = new URL("https://example.com/v1/messages");
      session.setProvider(provider1);
      // Provider 1's strict endpoint resolution will fail
      mocks.getPreferredProviderEndpoints.mockRejectedValueOnce(
        new Error("Endpoint resolution failed")
      );
      // After provider 1 fails, pick provider 2 as alternative
      mocks.pickRandomProviderWithExclusion.mockResolvedValueOnce(provider2);
      const doForward = vi.spyOn(
        ProxyForwarder as unknown as {
          doForward: (...args: unknown[]) => Promise<Response>;
        },
        "doForward"
      );
      const controller2 = new AbortController();
      // Only provider 2 reaches doForward (provider 1 fails at endpoint resolution)
      doForward.mockImplementationOnce(async (attemptSession) => {
        const runtime = attemptSession as ProxySession & AttemptRuntime;
        runtime.responseController = controller2;
        runtime.clearResponseTimeout = vi.fn();
        return createStreamingResponse({
          label: "p2",
          firstChunkDelayMs: 10,
          controller: controller2,
        });
      });
      const responsePromise = ProxyForwarder.send(session);
      await vi.advanceTimersByTimeAsync(200);
      const response = await responsePromise;
      expect(await response.text()).toContain('"provider":"p2"');
      expect(session.provider?.id).toBe(2);
      // Pin the premise above: provider 1 never reached doForward.
      expect(doForward).toHaveBeenCalledTimes(1);
      // Key assertion: since only provider 2 actually launched (provider 1 failed at
      // endpoint resolution before incrementing launchedProviderCount), the winner
      // should be classified as "request_success" not "hedge_winner". Check the whole
      // chain so a stray "hedge_winner" entry anywhere is caught, not just the first match.
      const winnerReasons = session
        .getProviderChain()
        .map((entry) => entry.reason)
        .filter((reason) => reason === "request_success" || reason === "hedge_winner");
      expect(winnerReasons).toContain("request_success");
      expect(winnerReasons).not.toContain("hedge_winner");
    } finally {
      vi.useRealTimers();
    }
  });
  996. });