// File: response-handler-lease-decrement.test.ts
/**
 * TDD: RED phase - tests for the lease budget decrement in response-handler.ts.
 *
 * Verifies that decrementLeaseBudget is called correctly after trackCostToRedis completes:
 * - All windows: 5h, daily, weekly, monthly
 * - All entity types: key, user, provider
 * - Zero-cost requests must NOT trigger a decrement
 * - Each window/entity pair is decremented exactly once per request (no duplicates)
 */
  10. import { beforeEach, describe, expect, it, vi } from "vitest";
  11. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  12. import type { ModelPriceData } from "@/types/model-price";
  13. // Track async tasks for draining
  14. const asyncTasks: Promise<void>[] = [];
  15. vi.mock("@/lib/async-task-manager", () => ({
  16. AsyncTaskManager: {
  17. register: (_taskId: string, promise: Promise<void>) => {
  18. asyncTasks.push(promise);
  19. return new AbortController();
  20. },
  21. cleanup: () => {},
  22. cancel: () => {},
  23. },
  24. }));
  25. vi.mock("@/lib/logger", () => ({
  26. logger: {
  27. debug: () => {},
  28. info: () => {},
  29. warn: () => {},
  30. error: () => {},
  31. trace: () => {},
  32. },
  33. }));
  34. vi.mock("@/lib/price-sync/cloud-price-updater", () => ({
  35. requestCloudPriceTableSync: () => {},
  36. }));
  37. vi.mock("@/repository/model-price", () => ({
  38. findLatestPriceByModel: vi.fn(),
  39. }));
  40. vi.mock("@/repository/system-config", () => ({
  41. getSystemSettings: vi.fn(),
  42. }));
  43. vi.mock("@/repository/message", () => ({
  44. updateMessageRequestCost: vi.fn(),
  45. updateMessageRequestCostWithBreakdown: vi.fn(),
  46. updateMessageRequestDetails: vi.fn(),
  47. updateMessageRequestDuration: vi.fn(),
  48. }));
  49. vi.mock("@/lib/session-manager", () => ({
  50. SessionManager: {
  51. updateSessionUsage: vi.fn(),
  52. storeSessionResponse: vi.fn(),
  53. extractCodexPromptCacheKey: vi.fn(),
  54. updateSessionWithCodexCacheKey: vi.fn(),
  55. },
  56. }));
  57. vi.mock("@/lib/rate-limit", () => ({
  58. RateLimitService: {
  59. trackCost: vi.fn(),
  60. trackUserDailyCost: vi.fn(),
  61. decrementLeaseBudget: vi.fn(),
  62. },
  63. }));
  64. vi.mock("@/lib/session-tracker", () => ({
  65. SessionTracker: {
  66. refreshSession: vi.fn(),
  67. },
  68. }));
  69. vi.mock("@/lib/proxy-status-tracker", () => ({
  70. ProxyStatusTracker: {
  71. getInstance: () => ({
  72. endRequest: () => {},
  73. }),
  74. },
  75. }));
  76. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  77. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  78. import { SessionManager } from "@/lib/session-manager";
  79. import { RateLimitService } from "@/lib/rate-limit";
  80. import { SessionTracker } from "@/lib/session-tracker";
  81. import {
  82. updateMessageRequestCost,
  83. updateMessageRequestDetails,
  84. updateMessageRequestDuration,
  85. } from "@/repository/message";
  86. import { findLatestPriceByModel } from "@/repository/model-price";
  87. import { getSystemSettings } from "@/repository/system-config";
  88. // Test price data
  89. const testPriceData: ModelPriceData = {
  90. input_cost_per_token: 0.000003,
  91. output_cost_per_token: 0.000015,
  92. };
  93. function makePriceRecord(modelName: string, priceData: ModelPriceData) {
  94. return {
  95. id: 1,
  96. modelName,
  97. priceData,
  98. createdAt: new Date(),
  99. updatedAt: new Date(),
  100. };
  101. }
  102. function makeSystemSettings(billingModelSource: "original" | "redirected" = "original") {
  103. return {
  104. billingModelSource,
  105. streamBufferEnabled: false,
  106. streamBufferMode: "none",
  107. streamBufferSize: 0,
  108. } as ReturnType<typeof getSystemSettings> extends Promise<infer T> ? T : never;
  109. }
  110. function createSession(opts: {
  111. originalModel: string;
  112. redirectedModel: string;
  113. sessionId: string;
  114. messageId: number;
  115. }): ProxySession {
  116. const { originalModel, redirectedModel, sessionId, messageId } = opts;
  117. const session = Object.create(ProxySession.prototype) as ProxySession;
  118. Object.assign(session, {
  119. request: { message: {}, log: "(test)", model: redirectedModel },
  120. startTime: Date.now(),
  121. method: "POST",
  122. requestUrl: new URL("http://localhost/v1/messages"),
  123. headers: new Headers(),
  124. headerLog: "",
  125. userAgent: null,
  126. context: {},
  127. clientAbortSignal: null,
  128. userName: "test-user",
  129. authState: null,
  130. provider: null,
  131. messageContext: null,
  132. sessionId: null,
  133. requestSequence: 1,
  134. originalFormat: "claude",
  135. providerType: null,
  136. originalModelName: null,
  137. originalUrlPathname: null,
  138. providerChain: [],
  139. cacheTtlResolved: null,
  140. context1mApplied: false,
  141. specialSettings: [],
  142. cachedPriceData: undefined,
  143. cachedBillingModelSource: undefined,
  144. resolvedPricingCache: new Map(),
  145. endpointPolicy: resolveEndpointPolicy("/v1/messages"),
  146. isHeaderModified: () => false,
  147. getContext1mApplied: () => false,
  148. getGroupCostMultiplier: () => 1,
  149. getOriginalModel: () => originalModel,
  150. getCurrentModel: () => redirectedModel,
  151. getProviderChain: () => [],
  152. getResolvedPricingByBillingSource: async () => ({
  153. resolvedModelName: redirectedModel,
  154. resolvedPricingProviderKey: "test-provider",
  155. source: "cloud_exact" as const,
  156. priceData: testPriceData,
  157. }),
  158. recordTtfb: () => 100,
  159. ttfbMs: null,
  160. getRequestSequence: () => 1,
  161. });
  162. (session as { setOriginalModel(m: string | null): void }).setOriginalModel = function (
  163. m: string | null
  164. ) {
  165. (this as { originalModelName: string | null }).originalModelName = m;
  166. };
  167. (session as { setSessionId(s: string): void }).setSessionId = function (s: string) {
  168. (this as { sessionId: string | null }).sessionId = s;
  169. };
  170. (session as { setProvider(p: unknown): void }).setProvider = function (p: unknown) {
  171. (this as { provider: unknown }).provider = p;
  172. };
  173. (session as { setAuthState(a: unknown): void }).setAuthState = function (a: unknown) {
  174. (this as { authState: unknown }).authState = a;
  175. };
  176. (session as { setMessageContext(c: unknown): void }).setMessageContext = function (c: unknown) {
  177. (this as { messageContext: unknown }).messageContext = c;
  178. };
  179. session.setOriginalModel(originalModel);
  180. session.setSessionId(sessionId);
  181. const provider = {
  182. id: 99,
  183. name: "test-provider",
  184. providerType: "claude",
  185. costMultiplier: 1.0,
  186. streamingIdleTimeoutMs: 0,
  187. dailyResetTime: "00:00",
  188. dailyResetMode: "fixed",
  189. } as unknown;
  190. const user = {
  191. id: 123,
  192. name: "test-user",
  193. dailyResetTime: "00:00",
  194. dailyResetMode: "fixed",
  195. } as unknown;
  196. const key = {
  197. id: 456,
  198. name: "test-key",
  199. dailyResetTime: "00:00",
  200. dailyResetMode: "fixed",
  201. } as unknown;
  202. session.setProvider(provider);
  203. session.setAuthState({
  204. user,
  205. key,
  206. apiKey: "sk-test",
  207. success: true,
  208. });
  209. session.setMessageContext({
  210. id: messageId,
  211. createdAt: new Date(),
  212. user,
  213. key,
  214. apiKey: "sk-test",
  215. });
  216. return session;
  217. }
  218. function createNonStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
  219. return new Response(
  220. JSON.stringify({
  221. type: "message",
  222. usage,
  223. }),
  224. {
  225. status: 200,
  226. headers: { "content-type": "application/json" },
  227. }
  228. );
  229. }
  230. function createStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
  231. const sseText = `event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`;
  232. const encoder = new TextEncoder();
  233. const stream = new ReadableStream<Uint8Array>({
  234. start(controller) {
  235. controller.enqueue(encoder.encode(sseText));
  236. controller.close();
  237. },
  238. });
  239. return new Response(stream, {
  240. status: 200,
  241. headers: { "content-type": "text/event-stream" },
  242. });
  243. }
  244. async function drainAsyncTasks(): Promise<void> {
  245. const tasks = asyncTasks.splice(0, asyncTasks.length);
  246. await Promise.all(tasks);
  247. }
  248. beforeEach(() => {
  249. vi.clearAllMocks();
  250. asyncTasks.splice(0, asyncTasks.length);
  251. });
  252. describe("Lease Budget Decrement after trackCostToRedis", () => {
  253. const originalModel = "claude-sonnet-4-20250514";
  254. const usage = { input_tokens: 1000, output_tokens: 500 };
  255. beforeEach(async () => {
  256. vi.mocked(getSystemSettings).mockResolvedValue(makeSystemSettings("original"));
  257. vi.mocked(findLatestPriceByModel).mockResolvedValue(
  258. makePriceRecord(originalModel, testPriceData)
  259. );
  260. vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
  261. vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
  262. vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
  263. vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
  264. vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
  265. vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({
  266. success: true,
  267. newRemaining: 10,
  268. });
  269. vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
  270. });
  271. it("should call decrementLeaseBudget for all windows and entity types (non-stream)", async () => {
  272. const session = createSession({
  273. originalModel,
  274. redirectedModel: originalModel,
  275. sessionId: "sess-lease-test-1",
  276. messageId: 5001,
  277. });
  278. const response = createNonStreamResponse(usage);
  279. await ProxyResponseHandler.dispatch(session, response);
  280. await drainAsyncTasks();
  281. // Expected cost: (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105
  282. const expectedCost = 0.0105;
  283. // Should be called 12 times:
  284. // 4 windows x 3 entity types = 12 calls
  285. // Windows: 5h, daily, weekly, monthly
  286. // Entity types: key(456), user(123), provider(99)
  287. expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
  288. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  289. expect(calls.length).toBe(12);
  290. // Verify all windows are covered for each entity type
  291. const windows = ["5h", "daily", "weekly", "monthly"];
  292. const entities = [
  293. { id: 456, type: "key" },
  294. { id: 123, type: "user" },
  295. { id: 99, type: "provider" },
  296. ];
  297. for (const entity of entities) {
  298. for (const window of windows) {
  299. const matchingCall = calls.find(
  300. (call) => call[0] === entity.id && call[1] === entity.type && call[2] === window
  301. );
  302. expect(matchingCall).toBeDefined();
  303. // Cost should be approximately 0.0105
  304. expect(matchingCall![3]).toBeCloseTo(expectedCost, 4);
  305. }
  306. }
  307. });
  308. it("should call decrementLeaseBudget for all windows and entity types (stream)", async () => {
  309. const session = createSession({
  310. originalModel,
  311. redirectedModel: originalModel,
  312. sessionId: "sess-lease-test-2",
  313. messageId: 5002,
  314. });
  315. const response = createStreamResponse(usage);
  316. const clientResponse = await ProxyResponseHandler.dispatch(session, response);
  317. await clientResponse.text();
  318. await drainAsyncTasks();
  319. expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
  320. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  321. // Should have exactly 12 calls (4 windows x 3 entity types)
  322. expect(calls.length).toBe(12);
  323. });
  324. it("should NOT call decrementLeaseBudget when cost is zero", async () => {
  325. // Mock price data that results in zero cost
  326. const zeroPriceData: ModelPriceData = {
  327. input_cost_per_token: 0,
  328. output_cost_per_token: 0,
  329. };
  330. vi.mocked(findLatestPriceByModel).mockResolvedValue(
  331. makePriceRecord(originalModel, zeroPriceData)
  332. );
  333. const session = createSession({
  334. originalModel,
  335. redirectedModel: originalModel,
  336. sessionId: "sess-lease-test-3",
  337. messageId: 5003,
  338. });
  339. // Override getResolvedPricingByBillingSource to return zero prices
  340. (
  341. session as {
  342. getResolvedPricingByBillingSource: () => Promise<{
  343. resolvedModelName: string;
  344. resolvedPricingProviderKey: string;
  345. source: string;
  346. priceData: ModelPriceData;
  347. }>;
  348. }
  349. ).getResolvedPricingByBillingSource = async () => ({
  350. resolvedModelName: originalModel,
  351. resolvedPricingProviderKey: "test-provider",
  352. source: "cloud_exact" as const,
  353. priceData: zeroPriceData,
  354. });
  355. const response = createNonStreamResponse(usage);
  356. await ProxyResponseHandler.dispatch(session, response);
  357. await drainAsyncTasks();
  358. // Zero cost should NOT trigger decrement
  359. expect(RateLimitService.decrementLeaseBudget).not.toHaveBeenCalled();
  360. });
  361. it("should call decrementLeaseBudget exactly once per request (no duplicates)", async () => {
  362. const session = createSession({
  363. originalModel,
  364. redirectedModel: originalModel,
  365. sessionId: "sess-lease-test-4",
  366. messageId: 5004,
  367. });
  368. const response = createNonStreamResponse(usage);
  369. await ProxyResponseHandler.dispatch(session, response);
  370. await drainAsyncTasks();
  371. // Each window/entity combo should be called exactly once
  372. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  373. // Create a unique key for each call to check for duplicates
  374. const callKeys = calls.map((call) => `${call[0]}-${call[1]}-${call[2]}`);
  375. const uniqueKeys = new Set(callKeys);
  376. // No duplicates: unique keys should equal total calls
  377. expect(uniqueKeys.size).toBe(calls.length);
  378. expect(calls.length).toBe(12); // 4 windows x 3 entities
  379. });
  380. it("should use correct entity IDs from session", async () => {
  381. const customKeyId = 789;
  382. const customUserId = 321;
  383. const customProviderId = 111;
  384. const session = createSession({
  385. originalModel,
  386. redirectedModel: originalModel,
  387. sessionId: "sess-lease-test-5",
  388. messageId: 5005,
  389. });
  390. // Override with custom IDs
  391. session.setProvider({
  392. id: customProviderId,
  393. name: "custom-provider",
  394. providerType: "claude",
  395. costMultiplier: 1.0,
  396. dailyResetTime: "00:00",
  397. dailyResetMode: "fixed",
  398. } as unknown);
  399. session.setAuthState({
  400. user: {
  401. id: customUserId,
  402. name: "custom-user",
  403. dailyResetTime: "00:00",
  404. dailyResetMode: "fixed",
  405. },
  406. key: {
  407. id: customKeyId,
  408. name: "custom-key",
  409. dailyResetTime: "00:00",
  410. dailyResetMode: "fixed",
  411. },
  412. apiKey: "sk-custom",
  413. success: true,
  414. });
  415. session.setMessageContext({
  416. id: 5005,
  417. createdAt: new Date(),
  418. user: {
  419. id: customUserId,
  420. name: "custom-user",
  421. dailyResetTime: "00:00",
  422. dailyResetMode: "fixed",
  423. },
  424. key: {
  425. id: customKeyId,
  426. name: "custom-key",
  427. dailyResetTime: "00:00",
  428. dailyResetMode: "fixed",
  429. },
  430. apiKey: "sk-custom",
  431. });
  432. const response = createNonStreamResponse(usage);
  433. await ProxyResponseHandler.dispatch(session, response);
  434. await drainAsyncTasks();
  435. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  436. // Verify key ID
  437. const keyCalls = calls.filter((c) => c[1] === "key");
  438. expect(keyCalls.every((c) => c[0] === customKeyId)).toBe(true);
  439. expect(keyCalls.length).toBe(4);
  440. // Verify user ID
  441. const userCalls = calls.filter((c) => c[1] === "user");
  442. expect(userCalls.every((c) => c[0] === customUserId)).toBe(true);
  443. expect(userCalls.length).toBe(4);
  444. // Verify provider ID
  445. const providerCalls = calls.filter((c) => c[1] === "provider");
  446. expect(providerCalls.every((c) => c[0] === customProviderId)).toBe(true);
  447. expect(providerCalls.length).toBe(4);
  448. });
  449. it("should use fire-and-forget pattern (not block on decrement failures)", async () => {
  450. // Mock decrementLeaseBudget to fail
  451. vi.mocked(RateLimitService.decrementLeaseBudget).mockRejectedValue(
  452. new Error("Redis connection failed")
  453. );
  454. const session = createSession({
  455. originalModel,
  456. redirectedModel: originalModel,
  457. sessionId: "sess-lease-test-6",
  458. messageId: 5006,
  459. });
  460. const response = createNonStreamResponse(usage);
  461. // Should NOT throw even if decrementLeaseBudget fails
  462. await expect(ProxyResponseHandler.dispatch(session, response)).resolves.toBeDefined();
  463. await drainAsyncTasks();
  464. // Verify decrement was attempted
  465. expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
  466. });
  467. });