response-handler-lease-decrement.test.ts 15 KB


/**
 * TDD: RED Phase - Tests for lease budget decrement in response-handler.ts
 *
 * Verifies that decrementLeaseBudget is called correctly after trackCostToRedis completes:
 * - All windows: 5h, daily, weekly, monthly
 * - All entity types: key, user, provider
 * - Zero-cost requests do NOT trigger a decrement
 * - Each window/entity pair is decremented exactly once per request (no duplicates)
 * - Entity IDs are taken from the session, and a failing decrement does not fail the request
 */
  10. import { beforeEach, describe, expect, it, vi } from "vitest";
  11. import type { ModelPriceData } from "@/types/model-price";
  12. // Track async tasks for draining
  13. const asyncTasks: Promise<void>[] = [];
  14. vi.mock("@/lib/async-task-manager", () => ({
  15. AsyncTaskManager: {
  16. register: (_taskId: string, promise: Promise<void>) => {
  17. asyncTasks.push(promise);
  18. return new AbortController();
  19. },
  20. cleanup: () => {},
  21. cancel: () => {},
  22. },
  23. }));
  24. vi.mock("@/lib/logger", () => ({
  25. logger: {
  26. debug: () => {},
  27. info: () => {},
  28. warn: () => {},
  29. error: () => {},
  30. trace: () => {},
  31. },
  32. }));
  33. vi.mock("@/lib/price-sync/cloud-price-updater", () => ({
  34. requestCloudPriceTableSync: () => {},
  35. }));
  36. vi.mock("@/repository/model-price", () => ({
  37. findLatestPriceByModel: vi.fn(),
  38. }));
  39. vi.mock("@/repository/system-config", () => ({
  40. getSystemSettings: vi.fn(),
  41. }));
  42. vi.mock("@/repository/message", () => ({
  43. updateMessageRequestCost: vi.fn(),
  44. updateMessageRequestDetails: vi.fn(),
  45. updateMessageRequestDuration: vi.fn(),
  46. }));
  47. vi.mock("@/lib/session-manager", () => ({
  48. SessionManager: {
  49. updateSessionUsage: vi.fn(),
  50. storeSessionResponse: vi.fn(),
  51. extractCodexPromptCacheKey: vi.fn(),
  52. updateSessionWithCodexCacheKey: vi.fn(),
  53. },
  54. }));
  55. vi.mock("@/lib/rate-limit", () => ({
  56. RateLimitService: {
  57. trackCost: vi.fn(),
  58. trackUserDailyCost: vi.fn(),
  59. decrementLeaseBudget: vi.fn(),
  60. },
  61. }));
  62. vi.mock("@/lib/session-tracker", () => ({
  63. SessionTracker: {
  64. refreshSession: vi.fn(),
  65. },
  66. }));
  67. vi.mock("@/lib/proxy-status-tracker", () => ({
  68. ProxyStatusTracker: {
  69. getInstance: () => ({
  70. endRequest: () => {},
  71. }),
  72. },
  73. }));
  74. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  75. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  76. import { SessionManager } from "@/lib/session-manager";
  77. import { RateLimitService } from "@/lib/rate-limit";
  78. import { SessionTracker } from "@/lib/session-tracker";
  79. import {
  80. updateMessageRequestCost,
  81. updateMessageRequestDetails,
  82. updateMessageRequestDuration,
  83. } from "@/repository/message";
  84. import { findLatestPriceByModel } from "@/repository/model-price";
  85. import { getSystemSettings } from "@/repository/system-config";
  86. // Test price data
  87. const testPriceData: ModelPriceData = {
  88. input_cost_per_token: 0.000003,
  89. output_cost_per_token: 0.000015,
  90. };
  91. function makePriceRecord(modelName: string, priceData: ModelPriceData) {
  92. return {
  93. id: 1,
  94. modelName,
  95. priceData,
  96. createdAt: new Date(),
  97. updatedAt: new Date(),
  98. };
  99. }
  100. function makeSystemSettings(billingModelSource: "original" | "redirected" = "original") {
  101. return {
  102. billingModelSource,
  103. streamBufferEnabled: false,
  104. streamBufferMode: "none",
  105. streamBufferSize: 0,
  106. } as ReturnType<typeof getSystemSettings> extends Promise<infer T> ? T : never;
  107. }
  108. function createSession(opts: {
  109. originalModel: string;
  110. redirectedModel: string;
  111. sessionId: string;
  112. messageId: number;
  113. }): ProxySession {
  114. const { originalModel, redirectedModel, sessionId, messageId } = opts;
  115. const session = Object.create(ProxySession.prototype) as ProxySession;
  116. Object.assign(session, {
  117. request: { message: {}, log: "(test)", model: redirectedModel },
  118. startTime: Date.now(),
  119. method: "POST",
  120. requestUrl: new URL("http://localhost/v1/messages"),
  121. headers: new Headers(),
  122. headerLog: "",
  123. userAgent: null,
  124. context: {},
  125. clientAbortSignal: null,
  126. userName: "test-user",
  127. authState: null,
  128. provider: null,
  129. messageContext: null,
  130. sessionId: null,
  131. requestSequence: 1,
  132. originalFormat: "claude",
  133. providerType: null,
  134. originalModelName: null,
  135. originalUrlPathname: null,
  136. providerChain: [],
  137. cacheTtlResolved: null,
  138. context1mApplied: false,
  139. specialSettings: [],
  140. cachedPriceData: undefined,
  141. cachedBillingModelSource: undefined,
  142. isHeaderModified: () => false,
  143. getContext1mApplied: () => false,
  144. getOriginalModel: () => originalModel,
  145. getCurrentModel: () => redirectedModel,
  146. getProviderChain: () => [],
  147. getCachedPriceDataByBillingSource: async () => testPriceData,
  148. recordTtfb: () => 100,
  149. ttfbMs: null,
  150. getRequestSequence: () => 1,
  151. });
  152. (session as { setOriginalModel(m: string | null): void }).setOriginalModel = function (
  153. m: string | null
  154. ) {
  155. (this as { originalModelName: string | null }).originalModelName = m;
  156. };
  157. (session as { setSessionId(s: string): void }).setSessionId = function (s: string) {
  158. (this as { sessionId: string | null }).sessionId = s;
  159. };
  160. (session as { setProvider(p: unknown): void }).setProvider = function (p: unknown) {
  161. (this as { provider: unknown }).provider = p;
  162. };
  163. (session as { setAuthState(a: unknown): void }).setAuthState = function (a: unknown) {
  164. (this as { authState: unknown }).authState = a;
  165. };
  166. (session as { setMessageContext(c: unknown): void }).setMessageContext = function (c: unknown) {
  167. (this as { messageContext: unknown }).messageContext = c;
  168. };
  169. session.setOriginalModel(originalModel);
  170. session.setSessionId(sessionId);
  171. const provider = {
  172. id: 99,
  173. name: "test-provider",
  174. providerType: "claude",
  175. costMultiplier: 1.0,
  176. streamingIdleTimeoutMs: 0,
  177. dailyResetTime: "00:00",
  178. dailyResetMode: "fixed",
  179. } as unknown;
  180. const user = {
  181. id: 123,
  182. name: "test-user",
  183. dailyResetTime: "00:00",
  184. dailyResetMode: "fixed",
  185. } as unknown;
  186. const key = {
  187. id: 456,
  188. name: "test-key",
  189. dailyResetTime: "00:00",
  190. dailyResetMode: "fixed",
  191. } as unknown;
  192. session.setProvider(provider);
  193. session.setAuthState({
  194. user,
  195. key,
  196. apiKey: "sk-test",
  197. success: true,
  198. });
  199. session.setMessageContext({
  200. id: messageId,
  201. createdAt: new Date(),
  202. user,
  203. key,
  204. apiKey: "sk-test",
  205. });
  206. return session;
  207. }
  208. function createNonStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
  209. return new Response(
  210. JSON.stringify({
  211. type: "message",
  212. usage,
  213. }),
  214. {
  215. status: 200,
  216. headers: { "content-type": "application/json" },
  217. }
  218. );
  219. }
  220. function createStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
  221. const sseText = `event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`;
  222. const encoder = new TextEncoder();
  223. const stream = new ReadableStream<Uint8Array>({
  224. start(controller) {
  225. controller.enqueue(encoder.encode(sseText));
  226. controller.close();
  227. },
  228. });
  229. return new Response(stream, {
  230. status: 200,
  231. headers: { "content-type": "text/event-stream" },
  232. });
  233. }
  234. async function drainAsyncTasks(): Promise<void> {
  235. const tasks = asyncTasks.splice(0, asyncTasks.length);
  236. await Promise.all(tasks);
  237. }
  238. beforeEach(() => {
  239. vi.clearAllMocks();
  240. asyncTasks.splice(0, asyncTasks.length);
  241. });
  242. describe("Lease Budget Decrement after trackCostToRedis", () => {
  243. const originalModel = "claude-sonnet-4-20250514";
  244. const usage = { input_tokens: 1000, output_tokens: 500 };
  245. beforeEach(async () => {
  246. vi.mocked(getSystemSettings).mockResolvedValue(makeSystemSettings("original"));
  247. vi.mocked(findLatestPriceByModel).mockResolvedValue(
  248. makePriceRecord(originalModel, testPriceData)
  249. );
  250. vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
  251. vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
  252. vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
  253. vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
  254. vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
  255. vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({
  256. success: true,
  257. newRemaining: 10,
  258. });
  259. vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
  260. });
  261. it("should call decrementLeaseBudget for all windows and entity types (non-stream)", async () => {
  262. const session = createSession({
  263. originalModel,
  264. redirectedModel: originalModel,
  265. sessionId: "sess-lease-test-1",
  266. messageId: 5001,
  267. });
  268. const response = createNonStreamResponse(usage);
  269. await ProxyResponseHandler.dispatch(session, response);
  270. await drainAsyncTasks();
  271. // Expected cost: (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105
  272. const expectedCost = 0.0105;
  273. // Should be called 12 times:
  274. // 4 windows x 3 entity types = 12 calls
  275. // Windows: 5h, daily, weekly, monthly
  276. // Entity types: key(456), user(123), provider(99)
  277. expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
  278. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  279. expect(calls.length).toBe(12);
  280. // Verify all windows are covered for each entity type
  281. const windows = ["5h", "daily", "weekly", "monthly"];
  282. const entities = [
  283. { id: 456, type: "key" },
  284. { id: 123, type: "user" },
  285. { id: 99, type: "provider" },
  286. ];
  287. for (const entity of entities) {
  288. for (const window of windows) {
  289. const matchingCall = calls.find(
  290. (call) => call[0] === entity.id && call[1] === entity.type && call[2] === window
  291. );
  292. expect(matchingCall).toBeDefined();
  293. // Cost should be approximately 0.0105
  294. expect(matchingCall![3]).toBeCloseTo(expectedCost, 4);
  295. }
  296. }
  297. });
  298. it("should call decrementLeaseBudget for all windows and entity types (stream)", async () => {
  299. const session = createSession({
  300. originalModel,
  301. redirectedModel: originalModel,
  302. sessionId: "sess-lease-test-2",
  303. messageId: 5002,
  304. });
  305. const response = createStreamResponse(usage);
  306. const clientResponse = await ProxyResponseHandler.dispatch(session, response);
  307. await clientResponse.text();
  308. await drainAsyncTasks();
  309. expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
  310. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  311. // Should have exactly 12 calls (4 windows x 3 entity types)
  312. expect(calls.length).toBe(12);
  313. });
  314. it("should NOT call decrementLeaseBudget when cost is zero", async () => {
  315. // Mock price data that results in zero cost
  316. const zeroPriceData: ModelPriceData = {
  317. input_cost_per_token: 0,
  318. output_cost_per_token: 0,
  319. };
  320. vi.mocked(findLatestPriceByModel).mockResolvedValue(
  321. makePriceRecord(originalModel, zeroPriceData)
  322. );
  323. const session = createSession({
  324. originalModel,
  325. redirectedModel: originalModel,
  326. sessionId: "sess-lease-test-3",
  327. messageId: 5003,
  328. });
  329. // Override getCachedPriceDataByBillingSource to return zero prices
  330. (
  331. session as { getCachedPriceDataByBillingSource: () => Promise<ModelPriceData> }
  332. ).getCachedPriceDataByBillingSource = async () => zeroPriceData;
  333. const response = createNonStreamResponse(usage);
  334. await ProxyResponseHandler.dispatch(session, response);
  335. await drainAsyncTasks();
  336. // Zero cost should NOT trigger decrement
  337. expect(RateLimitService.decrementLeaseBudget).not.toHaveBeenCalled();
  338. });
  339. it("should call decrementLeaseBudget exactly once per request (no duplicates)", async () => {
  340. const session = createSession({
  341. originalModel,
  342. redirectedModel: originalModel,
  343. sessionId: "sess-lease-test-4",
  344. messageId: 5004,
  345. });
  346. const response = createNonStreamResponse(usage);
  347. await ProxyResponseHandler.dispatch(session, response);
  348. await drainAsyncTasks();
  349. // Each window/entity combo should be called exactly once
  350. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  351. // Create a unique key for each call to check for duplicates
  352. const callKeys = calls.map((call) => `${call[0]}-${call[1]}-${call[2]}`);
  353. const uniqueKeys = new Set(callKeys);
  354. // No duplicates: unique keys should equal total calls
  355. expect(uniqueKeys.size).toBe(calls.length);
  356. expect(calls.length).toBe(12); // 4 windows x 3 entities
  357. });
  358. it("should use correct entity IDs from session", async () => {
  359. const customKeyId = 789;
  360. const customUserId = 321;
  361. const customProviderId = 111;
  362. const session = createSession({
  363. originalModel,
  364. redirectedModel: originalModel,
  365. sessionId: "sess-lease-test-5",
  366. messageId: 5005,
  367. });
  368. // Override with custom IDs
  369. session.setProvider({
  370. id: customProviderId,
  371. name: "custom-provider",
  372. providerType: "claude",
  373. costMultiplier: 1.0,
  374. dailyResetTime: "00:00",
  375. dailyResetMode: "fixed",
  376. } as unknown);
  377. session.setAuthState({
  378. user: {
  379. id: customUserId,
  380. name: "custom-user",
  381. dailyResetTime: "00:00",
  382. dailyResetMode: "fixed",
  383. },
  384. key: {
  385. id: customKeyId,
  386. name: "custom-key",
  387. dailyResetTime: "00:00",
  388. dailyResetMode: "fixed",
  389. },
  390. apiKey: "sk-custom",
  391. success: true,
  392. });
  393. session.setMessageContext({
  394. id: 5005,
  395. createdAt: new Date(),
  396. user: {
  397. id: customUserId,
  398. name: "custom-user",
  399. dailyResetTime: "00:00",
  400. dailyResetMode: "fixed",
  401. },
  402. key: {
  403. id: customKeyId,
  404. name: "custom-key",
  405. dailyResetTime: "00:00",
  406. dailyResetMode: "fixed",
  407. },
  408. apiKey: "sk-custom",
  409. });
  410. const response = createNonStreamResponse(usage);
  411. await ProxyResponseHandler.dispatch(session, response);
  412. await drainAsyncTasks();
  413. const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
  414. // Verify key ID
  415. const keyCalls = calls.filter((c) => c[1] === "key");
  416. expect(keyCalls.every((c) => c[0] === customKeyId)).toBe(true);
  417. expect(keyCalls.length).toBe(4);
  418. // Verify user ID
  419. const userCalls = calls.filter((c) => c[1] === "user");
  420. expect(userCalls.every((c) => c[0] === customUserId)).toBe(true);
  421. expect(userCalls.length).toBe(4);
  422. // Verify provider ID
  423. const providerCalls = calls.filter((c) => c[1] === "provider");
  424. expect(providerCalls.every((c) => c[0] === customProviderId)).toBe(true);
  425. expect(providerCalls.length).toBe(4);
  426. });
  427. it("should use fire-and-forget pattern (not block on decrement failures)", async () => {
  428. // Mock decrementLeaseBudget to fail
  429. vi.mocked(RateLimitService.decrementLeaseBudget).mockRejectedValue(
  430. new Error("Redis connection failed")
  431. );
  432. const session = createSession({
  433. originalModel,
  434. redirectedModel: originalModel,
  435. sessionId: "sess-lease-test-6",
  436. messageId: 5006,
  437. });
  438. const response = createNonStreamResponse(usage);
  439. // Should NOT throw even if decrementLeaseBudget fails
  440. await expect(ProxyResponseHandler.dispatch(session, response)).resolves.toBeDefined();
  441. await drainAsyncTasks();
  442. // Verify decrement was attempted
  443. expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
  444. });
  445. });