response-handler-endpoint-circuit-isolation.test.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. /**
  2. * Tests for endpoint circuit breaker isolation in response-handler.ts
  3. *
  4. * Verifies that key-level errors (fake 200, non-200 HTTP, stream abort) do NOT
  5. * call recordEndpointFailure. Only forwarder-level failures (timeout, network
  6. * error) and probe failures should penalize the endpoint circuit breaker.
  7. *
  8. * Streaming success DOES call recordEndpointSuccess (regression guard).
  9. */
  10. import { beforeEach, describe, expect, it, vi } from "vitest";
  11. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  12. import type { ModelPriceData } from "@/types/model-price";
  // Every promise handed to the mocked AsyncTaskManager is collected here so a
  // test can await the handler's background work before asserting.
  const asyncTasks: Promise<void>[] = [];
  vi.mock("@/lib/async-task-manager", () => ({
    AsyncTaskManager: {
      // Only touches `asyncTasks` when called (at test time), never while the
      // hoisted factory itself runs.
      register(_id: string, task: Promise<void>) {
        asyncTasks.push(task);
        return new AbortController();
      },
      cleanup() {},
      cancel() {},
    },
  }));
  // Stub out the logging, persistence, session and rate-limit dependencies the
  // response handler touches. vi.mock calls are hoisted above the module body,
  // so each factory below is self-contained.
  vi.mock("@/lib/logger", () => ({
    logger: {
      debug() {},
      info() {},
      warn() {},
      error() {},
      trace() {},
    },
  }));
  vi.mock("@/lib/price-sync/cloud-price-updater", () => ({
    requestCloudPriceTableSync() {},
  }));
  vi.mock("@/repository/model-price", () => ({
    findLatestPriceByModel: vi.fn(),
  }));
  vi.mock("@/repository/system-config", () => ({
    getSystemSettings: vi.fn(),
  }));
  vi.mock("@/repository/message", () => ({
    updateMessageRequestCost: vi.fn(),
    updateMessageRequestDetails: vi.fn(),
    updateMessageRequestDuration: vi.fn(),
  }));
  vi.mock("@/lib/session-manager", () => ({
    SessionManager: {
      updateSessionUsage: vi.fn(),
      storeSessionResponse: vi.fn(),
      clearSessionProvider: vi.fn(),
      extractCodexPromptCacheKey: vi.fn(),
      updateSessionWithCodexCacheKey: vi.fn(),
    },
  }));
  vi.mock("@/lib/rate-limit", () => ({
    RateLimitService: {
      trackCost: vi.fn(),
      trackUserDailyCost: vi.fn(),
      decrementLeaseBudget: vi.fn(),
    },
  }));
  vi.mock("@/lib/session-tracker", () => ({
    SessionTracker: {
      refreshSession: vi.fn(),
    },
  }));
  vi.mock("@/lib/proxy-status-tracker", () => ({
    ProxyStatusTracker: {
      // A fresh tracker object per call, whose endRequest does nothing.
      getInstance() {
        return { endRequest() {} };
      },
    },
  }));
  // Circuit-breaker spies. They are created through vi.hoisted so they already
  // exist when the hoisted vi.mock factories below run; a plain top-level const
  // would still be in its temporal dead zone at that point.
  const breakerSpies = vi.hoisted(() => ({
    mockRecordFailure: vi.fn(),
    mockRecordEndpointFailure: vi.fn(),
    mockRecordEndpointSuccess: vi.fn(),
  }));
  vi.mock("@/lib/circuit-breaker", () => ({
    recordFailure: breakerSpies.mockRecordFailure,
  }));
  vi.mock("@/lib/endpoint-circuit-breaker", () => ({
    recordEndpointFailure: breakerSpies.mockRecordEndpointFailure,
    recordEndpointSuccess: breakerSpies.mockRecordEndpointSuccess,
    resetEndpointCircuit: vi.fn(),
  }));
  // Short aliases used by setupCommonMocks and the assertions (read at test time).
  const { mockRecordFailure, mockRecordEndpointFailure, mockRecordEndpointSuccess } = breakerSpies;
  92. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  93. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  94. import { setDeferredStreamingFinalization } from "@/app/v1/_lib/proxy/stream-finalization";
  95. import { getSystemSettings } from "@/repository/system-config";
  96. import { findLatestPriceByModel } from "@/repository/model-price";
  97. import { updateMessageRequestDetails, updateMessageRequestDuration } from "@/repository/message";
  98. import { SessionManager } from "@/lib/session-manager";
  99. import { RateLimitService } from "@/lib/rate-limit";
  100. import { SessionTracker } from "@/lib/session-tracker";
  /**
   * Fixed price data served both by the session's cached-price lookup and by the
   * mocked findLatestPriceByModel: 3e-6 per input token and 1.5e-5 per output
   * token (3 and 15 per million tokens respectively).
   */
  const testPriceData: ModelPriceData = {
    input_cost_per_token: 3e-6,
    output_cost_per_token: 1.5e-5,
  };
  /**
   * Build a ProxySession stand-in for response-handler tests without running the
   * real constructor.
   *
   * The object is created from ProxySession.prototype and populated with only the
   * fields and methods these tests need; anything not assigned here (for example
   * setHighConcurrencyModeEnabled) falls through to the real prototype.
   *
   * @param opts.sessionId - Session id for the request. When omitted or undefined
   *   the default "fake-session" is used; an explicit null is preserved.
   * @returns The prepared session, with its original model set to "test-model".
   */
  function createSession(opts?: { sessionId?: string | null }): ProxySession {
    const session = Object.create(ProxySession.prototype) as ProxySession;
    // `??` would also replace an explicit null with the default, making the `null`
    // in the option type unreachable; only fall back when nothing was supplied.
    const requestedSessionId = opts?.sessionId;
    const sessionId = requestedSessionId === undefined ? "fake-session" : requestedSessionId;
    const provider = {
      id: 1,
      name: "test-provider",
      providerType: "claude" as const,
      baseUrl: "https://api.test.com",
      priority: 10,
      weight: 1,
      costMultiplier: 1,
      groupTag: "default",
      isEnabled: true,
      models: [],
      createdAt: new Date(),
      updatedAt: new Date(),
      streamingIdleTimeoutMs: 0,
      dailyResetTime: "00:00",
      dailyResetMode: "fixed",
    };
    const user = { id: 123, name: "test-user", dailyResetTime: "00:00", dailyResetMode: "fixed" };
    const key = { id: 456, name: "test-key", dailyResetTime: "00:00", dailyResetMode: "fixed" };
    Object.assign(session, {
      request: { message: {}, log: "(test)", model: "test-model" },
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("http://localhost/v1/messages"),
      headers: new Headers(),
      headerLog: "",
      userAgent: null,
      context: {},
      clientAbortSignal: null,
      userName: "test-user",
      authState: { user, key, apiKey: "sk-test", success: true },
      provider,
      messageContext: {
        id: 1,
        createdAt: new Date(),
        user,
        key,
        apiKey: "sk-test",
      },
      sessionId,
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      endpointPolicy: resolveEndpointPolicy("/v1/messages"),
      isHeaderModified: () => false,
      getContext1mApplied: () => false,
      getOriginalModel: () => "test-model",
      getCurrentModel: () => "test-model",
      getProviderChain: () => session.providerChain,
      getCachedPriceDataByBillingSource: async () => testPriceData,
      recordTtfb: () => 100,
      ttfbMs: null,
      getRequestSequence: () => 1,
      addProviderToChain: function (
        this: ProxySession & { providerChain: Record<string, unknown>[] },
        prov: {
          id: number;
          name: string;
          providerType: string;
          priority: number;
          weight: number;
          costMultiplier: number;
          groupTag: string;
          providerVendorId?: string;
        },
        metadata?: Record<string, unknown>
      ) {
        const metaTimestamp = metadata?.timestamp;
        this.providerChain.push({
          id: prov.id,
          name: prov.name,
          vendorId: prov.providerVendorId,
          providerType: prov.providerType,
          priority: prov.priority,
          weight: prov.weight,
          costMultiplier: prov.costMultiplier,
          groupTag: prov.groupTag,
          // Metadata may still override the provider fields above...
          ...(metadata ?? {}),
          // ...but the timestamp is written after the spread so an invalid
          // metadata.timestamp (non-number, NaN, Infinity) cannot overwrite the
          // validated value.
          timestamp:
            typeof metaTimestamp === "number" && Number.isFinite(metaTimestamp)
              ? metaTimestamp
              : Date.now(),
        });
      },
    });
    // Instance-level setter overrides that write straight to the fields assigned above.
    (session as { setOriginalModel(m: string | null): void }).setOriginalModel = function (
      m: string | null
    ) {
      (this as { originalModelName: string | null }).originalModelName = m;
    };
    (session as { setSessionId(s: string): void }).setSessionId = function (s: string) {
      (this as { sessionId: string | null }).sessionId = s;
    };
    (session as { setProvider(p: unknown): void }).setProvider = function (p: unknown) {
      (this as { provider: unknown }).provider = p;
    };
    (session as { setAuthState(a: unknown): void }).setAuthState = function (a: unknown) {
      (this as { authState: unknown }).authState = a;
    };
    (session as { setMessageContext(c: unknown): void }).setMessageContext = function (c: unknown) {
      (this as { messageContext: unknown }).messageContext = c;
    };
    session.setOriginalModel("test-model");
    return session;
  }
  /**
   * Attach deferred streaming-finalization metadata describing a single,
   * first-attempt request to provider 1 ("test-provider").
   *
   * @param session - Session to attach the metadata to.
   * @param endpointId - Endpoint id to record (default 42); pass null for a
   *   request without an endpointId.
   * @param upstreamStatusCode - Upstream HTTP status to record (default 200).
   */
  function setDeferredMeta(
    session: ProxySession,
    endpointId: number | null = 42,
    upstreamStatusCode: number = 200
  ) {
    setDeferredStreamingFinalization(session, {
      providerId: 1,
      providerName: "test-provider",
      providerPriority: 10,
      attemptNumber: 1,
      totalProvidersAttempted: 1,
      isFirstAttempt: true,
      isFailoverSuccess: false,
      endpointId,
      endpointUrl: "https://api.test.com",
      upstreamStatusCode,
    });
  }
  /**
   * Wrap raw SSE text in a `text/event-stream` Response whose body emits the
   * whole payload as a single chunk and then closes.
   *
   * @param sseText - Fully framed SSE text (including the trailing blank line).
   * @param status - HTTP status code to report; must be a status that allows a
   *   body (the Response constructor rejects e.g. 204/304 with a body).
   */
  function createSseResponse(sseText: string, status: number): Response {
    const payload = new TextEncoder().encode(sseText);
    const stream = new ReadableStream<Uint8Array>({
      start(controller) {
        controller.enqueue(payload);
        controller.close();
      },
    });
    return new Response(stream, {
      status,
      headers: { "content-type": "text/event-stream" },
    });
  }

  /**
   * Create an SSE stream that emits a fake-200 error body (valid HTTP 200 but an
   * error object in the content).
   *
   * @param errorMessage - Placed in `error.message` (default "invalid api key").
   */
  function createFake200StreamResponse(errorMessage: string = "invalid api key"): Response {
    return createSseResponse(`data: ${JSON.stringify({ error: { message: errorMessage } })}\n\n`, 200);
  }

  /**
   * Create an SSE stream that returns a non-200 HTTP status with an error body.
   *
   * @param statusCode - HTTP status to report.
   * @param errorMessage - Placed in `error` (default "rate limit exceeded").
   */
  function createNon200StreamResponse(
    statusCode: number,
    errorMessage: string = "rate limit exceeded"
  ): Response {
    return createSseResponse(`data: ${JSON.stringify({ error: errorMessage })}\n\n`, statusCode);
  }

  /** Create a successful SSE stream with a message_delta event carrying usage data. */
  function createSuccessStreamResponse(): Response {
    const usage = { input_tokens: 100, output_tokens: 50 };
    return createSseResponse(`event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`, 200);
  }
  /**
   * Await every task registered through the mocked AsyncTaskManager.
   *
   * A task may register further tasks while it runs; a single splice would miss
   * those, so keep draining batches until the queue stays empty. Rejections
   * propagate to the caller (failing the test).
   */
  async function drainAsyncTasks(): Promise<void> {
    while (asyncTasks.length > 0) {
      const batch = asyncTasks.splice(0, asyncTasks.length);
      await Promise.all(batch);
    }
  }
  /**
   * Install baseline resolved values on the mocked dependencies the response
   * handler awaits, so every test starts from an "all calls succeed" state.
   * Invoked from the describe-level beforeEach, after the top-level beforeEach
   * has run vi.clearAllMocks().
   */
  function setupCommonMocks() {
    // Only the settings fields listed here are supplied, hence the cast to the
    // full resolved settings type.
    vi.mocked(getSystemSettings).mockResolvedValue({
      billingModelSource: "original",
      streamBufferEnabled: false,
      streamBufferMode: "none",
      streamBufferSize: 0,
    } as Awaited<ReturnType<typeof getSystemSettings>>);
    vi.mocked(findLatestPriceByModel).mockResolvedValue({
      id: 1,
      modelName: "test-model",
      priceData: testPriceData,
      createdAt: new Date(),
      updatedAt: new Date(),
    });
    vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
    vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
    vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
    vi.mocked(SessionManager.clearSessionProvider).mockResolvedValue(undefined);
    vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
    vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
    vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({
      success: true,
      newRemaining: 10,
    });
    vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
    mockRecordFailure.mockResolvedValue(undefined);
    mockRecordEndpointFailure.mockResolvedValue(undefined);
    mockRecordEndpointSuccess.mockResolvedValue(undefined);
  }
  // Before every test: wipe spy state and drop any background tasks a previous
  // test left queued without draining.
  beforeEach(() => {
    vi.clearAllMocks();
    asyncTasks.length = 0;
  });
  describe("Endpoint circuit breaker isolation", () => {
    /** Dispatch a response through the handler, then wait for its background tasks. */
    async function dispatchAndDrain(session: ProxySession, response: Response): Promise<void> {
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();
    }

    /** Matcher for the Error passed to recordFailure when a fake-200 body is detected. */
    function fake200Failure() {
      return expect.objectContaining({ message: expect.stringContaining("FAKE_200") });
    }

    /**
     * Whether the session's provider chain holds an entry for provider 1 with the
     * given reason and status code, flagged with statusCodeInferred === true.
     */
    function hasInferredChainEntry(
      session: ProxySession,
      reason: string,
      statusCode: number
    ): boolean {
      return session
        .getProviderChain()
        .some(
          (entry) =>
            entry.id === 1 &&
            entry.reason === reason &&
            entry.statusCode === statusCode &&
            entry.statusCodeInferred === true
        );
    }

    beforeEach(() => {
      setupCommonMocks();
    });

    it("fake-200 error should call recordFailure but NOT recordEndpointFailure", async () => {
      const session = createSession();
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createFake200StreamResponse());

      expect(mockRecordFailure).toHaveBeenCalledWith(1, fake200Failure());
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
      expect(SessionManager.clearSessionProvider).toHaveBeenCalledWith("fake-session");
      expect(hasInferredChainEntry(session, "retry_failed", 401)).toBe(true);
    });

    // High-concurrency mode: a fake-200 streaming error must still record the core
    // provider failure, but the session observability writes are skipped.
    it("高并发模式下,fake-200 流式错误仍应记录核心失败,但跳过 session 观测写入", async () => {
      const session = createSession();
      session.setHighConcurrencyModeEnabled(true);
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createFake200StreamResponse());

      expect(mockRecordFailure).toHaveBeenCalledWith(1, fake200Failure());
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
      expect(SessionManager.clearSessionProvider).toHaveBeenCalledWith("fake-session");
      expect(SessionManager.updateSessionUsage).not.toHaveBeenCalled();
      expect(SessionTracker.refreshSession).not.toHaveBeenCalled();
    });

    it("fake-200 inferred 404 should NOT call recordFailure and should be marked as resource_not_found", async () => {
      const session = createSession();
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createFake200StreamResponse("model not found"));

      expect(mockRecordFailure).not.toHaveBeenCalled();
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
      expect(SessionManager.clearSessionProvider).toHaveBeenCalledWith("fake-session");
      expect(hasInferredChainEntry(session, "resource_not_found", 404)).toBe(true);
    });

    it("non-200 HTTP status should call recordFailure but NOT recordEndpointFailure", async () => {
      const session = createSession();
      // Same deferred metadata as the other tests, but the upstream reported 429.
      setDeferredStreamingFinalization(session, {
        providerId: 1,
        providerName: "test-provider",
        providerPriority: 10,
        attemptNumber: 1,
        totalProvidersAttempted: 1,
        isFirstAttempt: true,
        isFailoverSuccess: false,
        endpointId: 42,
        endpointUrl: "https://api.test.com",
        upstreamStatusCode: 429,
      });

      await dispatchAndDrain(session, createNon200StreamResponse(429));

      expect(mockRecordFailure).toHaveBeenCalledWith(1, expect.any(Error));
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
    });

    it("streaming success DOES call recordEndpointSuccess (regression guard)", async () => {
      const session = createSession();
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createSuccessStreamResponse());

      expect(mockRecordEndpointSuccess).toHaveBeenCalledWith(42);
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
    });

    it("streaming success without endpointId should NOT call any endpoint circuit breaker function", async () => {
      const session = createSession();
      setDeferredMeta(session, null);

      await dispatchAndDrain(session, createSuccessStreamResponse());

      expect(mockRecordEndpointSuccess).not.toHaveBeenCalled();
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
    });
  });