2
0

response-handler-endpoint-circuit-isolation.test.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. /**
  2. * Tests for endpoint circuit breaker isolation in response-handler.ts
  3. *
  4. * Verifies that key-level errors (fake 200, non-200 HTTP) do NOT call recordEndpointFailure;
  5. * stream aborts are also key-level but are not exercised in this file. Only forwarder-level
  6. * failures (timeout, network error) and probe failures should penalize the endpoint circuit breaker.
  7. *
  8. * Streaming success DOES call recordEndpointSuccess (regression guard).
  9. */
  10. import { beforeEach, describe, expect, it, vi } from "vitest";
  11. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  12. import type { ModelPriceData } from "@/types/model-price";
  // Every promise handed to the mocked AsyncTaskManager lands here so tests can
  // await background work before asserting. Created with vi.hoisted (like the
  // circuit-breaker spies) so the array exists before any hoisted vi.mock runs.
  const { asyncTasks } = vi.hoisted(() => ({ asyncTasks: [] as Promise<void>[] }));
  vi.mock("@/lib/async-task-manager", () => ({
    AsyncTaskManager: {
      register(_taskId: string, promise: Promise<void>) {
        asyncTasks.push(promise);
        return new AbortController();
      },
      cleanup() {},
      cancel() {},
    },
  }));
  // Silence logging: every level is a shared no-op.
  vi.mock("@/lib/logger", () => {
    const noop = () => {};
    return {
      logger: { debug: noop, info: noop, warn: noop, error: noop, trace: noop },
    };
  });
  vi.mock("@/lib/price-sync/cloud-price-updater", () => ({
    requestCloudPriceTableSync: () => {},
  }));
  // Repository layer: bare spies; several receive default resolutions in setupCommonMocks().
  vi.mock("@/repository/model-price", () => ({ findLatestPriceByModel: vi.fn() }));
  vi.mock("@/repository/system-config", () => ({ getSystemSettings: vi.fn() }));
  vi.mock("@/repository/message", () => ({
    updateMessageRequestCost: vi.fn(),
    updateMessageRequestDetails: vi.fn(),
    updateMessageRequestDuration: vi.fn(),
  }));
  // Session / rate-limit services: spies only.
  vi.mock("@/lib/session-manager", () => ({
    SessionManager: {
      updateSessionUsage: vi.fn(),
      storeSessionResponse: vi.fn(),
      extractCodexPromptCacheKey: vi.fn(),
      updateSessionWithCodexCacheKey: vi.fn(),
    },
  }));
  vi.mock("@/lib/rate-limit", () => ({
    RateLimitService: {
      trackCost: vi.fn(),
      trackUserDailyCost: vi.fn(),
      decrementLeaseBudget: vi.fn(),
    },
  }));
  vi.mock("@/lib/session-tracker", () => ({ SessionTracker: { refreshSession: vi.fn() } }));
  // Each getInstance() call yields a fresh tracker whose endRequest does nothing.
  vi.mock("@/lib/proxy-status-tracker", () => ({
    ProxyStatusTracker: {
      getInstance() {
        return { endRequest() {} };
      },
    },
  }));
  // Spies for both circuit breakers. vi.hoisted() is evaluated before the hoisted
  // vi.mock factories, so the factories below can reference these without a TDZ error.
  const { mockRecordFailure, mockRecordEndpointFailure, mockRecordEndpointSuccess } = vi.hoisted(
    () => {
      const failure = vi.fn();
      const endpointFailure = vi.fn();
      const endpointSuccess = vi.fn();
      return {
        mockRecordFailure: failure,
        mockRecordEndpointFailure: endpointFailure,
        mockRecordEndpointSuccess: endpointSuccess,
      };
    }
  );
  // Provider-level breaker.
  vi.mock("@/lib/circuit-breaker", () => ({ recordFailure: mockRecordFailure }));
  // Endpoint-level breaker: the isolation assertions in this file target these spies.
  vi.mock("@/lib/endpoint-circuit-breaker", () => ({
    recordEndpointFailure: mockRecordEndpointFailure,
    recordEndpointSuccess: mockRecordEndpointSuccess,
    resetEndpointCircuit: vi.fn(),
  }));
  91. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  92. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  93. import { setDeferredStreamingFinalization } from "@/app/v1/_lib/proxy/stream-finalization";
  94. import { getSystemSettings } from "@/repository/system-config";
  95. import { findLatestPriceByModel } from "@/repository/model-price";
  96. import { updateMessageRequestDetails, updateMessageRequestDuration } from "@/repository/message";
  97. import { SessionManager } from "@/lib/session-manager";
  98. import { RateLimitService } from "@/lib/rate-limit";
  99. import { SessionTracker } from "@/lib/session-tracker";
  /** Flat per-token prices served by the mocked price lookups in these tests. */
  const testPriceData: ModelPriceData = {
    input_cost_per_token: 3e-6,
    output_cost_per_token: 1.5e-5,
  };
  /**
   * Build a ProxySession test double without invoking the real constructor.
   *
   * The object inherits from ProxySession.prototype and is populated with the
   * state ProxyResponseHandler reads in these tests: provider 1
   * ("test-provider", type "claude"), user 123 / key 456, model "test-model",
   * the "/v1/messages" endpoint policy, and stubbed accessors
   * (getCachedPriceDataByBillingSource resolves to testPriceData, recordTtfb
   * returns 100, getProviderChain returns the live providerChain array).
   *
   * @param opts - Optional overrides; `sessionId` defaults to null.
   * @returns The stubbed session with its original model set to "test-model".
   */
  function createSession(opts?: { sessionId?: string | null }): ProxySession {
    const session = Object.create(ProxySession.prototype) as ProxySession;
    const provider = {
      id: 1,
      name: "test-provider",
      providerType: "claude" as const,
      baseUrl: "https://api.test.com",
      priority: 10,
      weight: 1,
      costMultiplier: 1,
      groupTag: "default",
      isEnabled: true,
      models: [],
      createdAt: new Date(),
      updatedAt: new Date(),
      streamingIdleTimeoutMs: 0,
      dailyResetTime: "00:00",
      dailyResetMode: "fixed",
    };
    const user = { id: 123, name: "test-user", dailyResetTime: "00:00", dailyResetMode: "fixed" };
    const key = { id: 456, name: "test-key", dailyResetTime: "00:00", dailyResetMode: "fixed" };

    Object.assign(session, {
      request: { message: {}, log: "(test)", model: "test-model" },
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("http://localhost/v1/messages"),
      headers: new Headers(),
      headerLog: "",
      userAgent: null,
      context: {},
      clientAbortSignal: null,
      userName: "test-user",
      authState: { user, key, apiKey: "sk-test", success: true },
      provider,
      messageContext: {
        id: 1,
        createdAt: new Date(),
        user,
        key,
        apiKey: "sk-test",
      },
      sessionId: opts?.sessionId ?? null,
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      endpointPolicy: resolveEndpointPolicy("/v1/messages"),
      isHeaderModified: () => false,
      getContext1mApplied: () => false,
      getOriginalModel: () => "test-model",
      getCurrentModel: () => "test-model",
      getProviderChain: () => session.providerChain,
      getCachedPriceDataByBillingSource: async () => testPriceData,
      recordTtfb: () => 100,
      ttfbMs: null,
      getRequestSequence: () => 1,
      addProviderToChain: function (
        this: ProxySession & { providerChain: Record<string, unknown>[] },
        prov: {
          id: number;
          name: string;
          providerType: string;
          priority: number;
          weight: number;
          costMultiplier: number;
          groupTag: string;
          providerVendorId?: string;
        },
        metadata?: Record<string, unknown>
      ) {
        // Only a finite numeric metadata.timestamp is honoured; anything else
        // falls back to the current time.
        const timestamp =
          typeof metadata?.timestamp === "number" && Number.isFinite(metadata.timestamp)
            ? metadata.timestamp
            : Date.now();
        this.providerChain.push({
          id: prov.id,
          name: prov.name,
          vendorId: prov.providerVendorId,
          providerType: prov.providerType,
          priority: prov.priority,
          weight: prov.weight,
          costMultiplier: prov.costMultiplier,
          groupTag: prov.groupTag,
          ...(metadata ?? {}),
          // Written after the metadata spread so an invalid metadata.timestamp
          // (NaN, a string, ...) cannot overwrite the sanitized value.
          timestamp,
        });
      },
    });

    // Own-property setters that write the backing fields directly.
    // NOTE(review): these shadow any same-named ProxySession.prototype methods,
    // presumably to bypass their side effects — confirm against session.ts.
    Object.assign(session, {
      setOriginalModel(this: { originalModelName: string | null }, m: string | null) {
        this.originalModelName = m;
      },
      setSessionId(this: { sessionId: string | null }, s: string) {
        this.sessionId = s;
      },
      setProvider(this: { provider: unknown }, p: unknown) {
        this.provider = p;
      },
      setAuthState(this: { authState: unknown }, a: unknown) {
        this.authState = a;
      },
      setMessageContext(this: { messageContext: unknown }, c: unknown) {
        this.messageContext = c;
      },
    });

    session.setOriginalModel("test-model");
    return session;
  }
  /**
   * Attach deferred streaming-finalization metadata describing a single,
   * first-attempt, non-failover request to provider 1 ("test-provider",
   * priority 10) at "https://api.test.com".
   *
   * @param session - Session to annotate.
   * @param endpointId - Endpoint id to record; pass `null` to simulate a request
   *   with no resolved endpoint. Defaults to 42.
   * @param upstreamStatusCode - HTTP status reported by the upstream. Defaults to
   *   200; pass e.g. 429 to describe a non-200 upstream response.
   */
  function setDeferredMeta(
    session: ProxySession,
    endpointId: number | null = 42,
    upstreamStatusCode: number = 200
  ) {
    setDeferredStreamingFinalization(session, {
      providerId: 1,
      providerName: "test-provider",
      providerPriority: 10,
      attemptNumber: 1,
      totalProvidersAttempted: 1,
      isFirstAttempt: true,
      isFailoverSuccess: false,
      endpointId,
      endpointUrl: "https://api.test.com",
      upstreamStatusCode,
    });
  }
  /**
   * Wrap `sseText` in a `text/event-stream` Response whose body is a
   * ReadableStream that enqueues the UTF-8 encoded text as one chunk and closes.
   *
   * @param sseText - Raw SSE payload (already framed with `data:` / blank lines).
   * @param status - HTTP status of the response. Defaults to 200.
   */
  function createSseResponse(sseText: string, status: number = 200): Response {
    const encoder = new TextEncoder();
    const stream = new ReadableStream<Uint8Array>({
      start(controller) {
        controller.enqueue(encoder.encode(sseText));
        controller.close();
      },
    });
    return new Response(stream, {
      status,
      headers: { "content-type": "text/event-stream" },
    });
  }

  /** Create an SSE stream that emits a fake-200 error body (valid HTTP 200 but error in content). */
  function createFake200StreamResponse(errorMessage: string = "invalid api key"): Response {
    return createSseResponse(`data: ${JSON.stringify({ error: { message: errorMessage } })}\n\n`);
  }

  /**
   * Create an SSE stream that returns a non-200 HTTP status with an error body.
   *
   * @param statusCode - HTTP status to report.
   * @param errorMessage - Value of the body's `error` field. Defaults to "rate limit exceeded".
   */
  function createNon200StreamResponse(
    statusCode: number,
    errorMessage: string = "rate limit exceeded"
  ): Response {
    return createSseResponse(`data: ${JSON.stringify({ error: errorMessage })}\n\n`, statusCode);
  }

  /** Create a successful SSE stream carrying a `message_delta` event with usage data. */
  function createSuccessStreamResponse(): Response {
    const usage = { input_tokens: 100, output_tokens: 50 };
    return createSseResponse(`event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`);
  }
  /**
   * Await every promise registered with the mocked AsyncTaskManager.
   *
   * Repeats until the queue is empty, so tasks registered while earlier ones
   * were settling are awaited too (a single snapshot would silently skip them).
   * Rejects if any awaited task rejects.
   */
  async function drainAsyncTasks(): Promise<void> {
    while (asyncTasks.length > 0) {
      // splice(0) empties the shared array in place and returns its contents.
      const pending = asyncTasks.splice(0);
      await Promise.all(pending);
    }
  }
  /**
   * Install benign defaults on the mocked collaborators so a dispatch can run
   * to completion:
   * - system settings billing by the original model with stream buffering off;
   * - a "test-model" price row carrying testPriceData;
   * - message, session and session-tracker writes resolving to undefined;
   * - rate-limit cost tracking resolving to undefined and a lease-budget
   *   decrement that succeeds with 10 remaining;
   * - circuit-breaker spies resolving to undefined.
   */
  function setupCommonMocks() {
    type SystemSettings = Awaited<ReturnType<typeof getSystemSettings>>;
    const settings = {
      billingModelSource: "original",
      streamBufferEnabled: false,
      streamBufferMode: "none",
      streamBufferSize: 0,
    } as SystemSettings;
    vi.mocked(getSystemSettings).mockResolvedValue(settings);

    vi.mocked(findLatestPriceByModel).mockResolvedValue({
      id: 1,
      modelName: "test-model",
      priceData: testPriceData,
      createdAt: new Date(),
      updatedAt: new Date(),
    });

    // Persistence and session bookkeeping.
    vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
    vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
    vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
    vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);

    // Rate limiting.
    vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
    vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
    vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({
      success: true,
      newRemaining: 10,
    });

    // Circuit breakers.
    mockRecordFailure.mockResolvedValue(undefined);
    mockRecordEndpointFailure.mockResolvedValue(undefined);
    mockRecordEndpointSuccess.mockResolvedValue(undefined);
  }
  // Global reset before every test: wipe recorded calls on all mocks and drop
  // any task promises left over from the previous test.
  beforeEach(() => {
    vi.clearAllMocks();
    asyncTasks.length = 0;
  });
  describe("Endpoint circuit breaker isolation", () => {
    beforeEach(() => setupCommonMocks());

    /** Dispatch `response` through the handler, then settle its background tasks. */
    async function dispatchAndDrain(session: ProxySession, response: Response): Promise<void> {
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();
    }

    /**
     * Whether the session's provider chain holds an entry for provider 1 with the
     * given `reason` and `statusCode`, flagged as inferred (`statusCodeInferred === true`).
     */
    function hasInferredChainEntry(
      session: ProxySession,
      reason: string,
      statusCode: number
    ): boolean {
      return session
        .getProviderChain()
        .some(
          (entry) =>
            entry.id === 1 &&
            entry.reason === reason &&
            entry.statusCode === statusCode &&
            entry.statusCodeInferred === true
        );
    }

    it("fake-200 error should call recordFailure but NOT recordEndpointFailure", async () => {
      const session = createSession();
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createFake200StreamResponse());

      expect(mockRecordFailure).toHaveBeenCalledWith(
        1,
        expect.objectContaining({ message: expect.stringContaining("FAKE_200") })
      );
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
      expect(hasInferredChainEntry(session, "retry_failed", 401)).toBe(true);
    });

    it("fake-200 inferred 404 should NOT call recordFailure and should be marked as resource_not_found", async () => {
      const session = createSession();
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createFake200StreamResponse("model not found"));

      expect(mockRecordFailure).not.toHaveBeenCalled();
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
      expect(hasInferredChainEntry(session, "resource_not_found", 404)).toBe(true);
    });

    it("non-200 HTTP status should call recordFailure but NOT recordEndpointFailure", async () => {
      const session = createSession();
      // Deferred meta reports the upstream 429.
      setDeferredStreamingFinalization(session, {
        providerId: 1,
        providerName: "test-provider",
        providerPriority: 10,
        attemptNumber: 1,
        totalProvidersAttempted: 1,
        isFirstAttempt: true,
        isFailoverSuccess: false,
        endpointId: 42,
        endpointUrl: "https://api.test.com",
        upstreamStatusCode: 429,
      });

      await dispatchAndDrain(session, createNon200StreamResponse(429));

      expect(mockRecordFailure).toHaveBeenCalledWith(1, expect.any(Error));
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
    });

    it("streaming success DOES call recordEndpointSuccess (regression guard)", async () => {
      const session = createSession();
      setDeferredMeta(session, 42);

      await dispatchAndDrain(session, createSuccessStreamResponse());

      expect(mockRecordEndpointSuccess).toHaveBeenCalledWith(42);
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
    });

    it("streaming success without endpointId should NOT call any endpoint circuit breaker function", async () => {
      const session = createSession();
      setDeferredMeta(session, null);

      await dispatchAndDrain(session, createSuccessStreamResponse());

      expect(mockRecordEndpointSuccess).not.toHaveBeenCalled();
      expect(mockRecordEndpointFailure).not.toHaveBeenCalled();
    });
  });