response-handler-lease-decrement.test.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515
  /**
   * TDD: RED phase — tests for the lease budget decrement in response-handler.ts.
   *
   * Verifies that decrementLeaseBudget is called correctly once trackCostToRedis completes:
   * - Every window: 5h, daily, weekly, monthly
   * - Every entity type: key, user, provider
   * - Zero-cost requests must NOT trigger a decrement
   * - Each (entity, window) pair is decremented exactly once per request (no duplicates)
   */
  10. import { beforeEach, describe, expect, it, vi } from "vitest";
  11. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  12. import type { ModelPriceData } from "@/types/model-price";
  // Promises registered through the mocked AsyncTaskManager. drainAsyncTasks() awaits them
  // so that background work started by the response handler settles before assertions.
  // vi.hoisted() initializes the array together with the hoisted vi.mock() calls, so it
  // already exists even if a task is registered while mocked/imported modules evaluate.
  const asyncTasks = vi.hoisted((): Promise<void>[] => []);

  vi.mock("@/lib/async-task-manager", () => ({
    AsyncTaskManager: {
      // Capture the task promise instead of managing it; hand back a fresh controller.
      register: (_taskId: string, promise: Promise<void>) => {
        asyncTasks.push(promise);
        return new AbortController();
      },
      cleanup: () => {},
      cancel: () => {},
    },
  }));

  // Silence all log output.
  vi.mock("@/lib/logger", () => ({
    logger: {
      debug: () => {},
      info: () => {},
      warn: () => {},
      error: () => {},
      trace: () => {},
    },
  }));

  vi.mock("@/lib/price-sync/cloud-price-updater", () => ({
    requestCloudPriceTableSync: () => {},
  }));

  // Persistence layer: every repository function is a vi.fn(), configured per test.
  vi.mock("@/repository/model-price", () => ({
    findLatestPriceByModel: vi.fn(),
  }));

  vi.mock("@/repository/system-config", () => ({
    getSystemSettings: vi.fn(),
  }));

  vi.mock("@/repository/message", () => ({
    updateMessageRequestCost: vi.fn(),
    updateMessageRequestDetails: vi.fn(),
    updateMessageRequestDuration: vi.fn(),
  }));

  vi.mock("@/lib/session-manager", () => ({
    SessionManager: {
      updateSessionUsage: vi.fn(),
      storeSessionResponse: vi.fn(),
      extractCodexPromptCacheKey: vi.fn(),
      updateSessionWithCodexCacheKey: vi.fn(),
    },
  }));

  // Rate limiting is stubbed; tests inspect the calls made to decrementLeaseBudget.
  vi.mock("@/lib/rate-limit", () => ({
    RateLimitService: {
      trackCost: vi.fn(),
      trackUserDailyCost: vi.fn(),
      decrementLeaseBudget: vi.fn(),
    },
  }));

  vi.mock("@/lib/session-tracker", () => ({
    SessionTracker: {
      refreshSession: vi.fn(),
    },
  }));

  vi.mock("@/lib/proxy-status-tracker", () => ({
    ProxyStatusTracker: {
      getInstance: () => ({
        endRequest: () => {},
      }),
    },
  }));
  75. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  76. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  77. import { SessionManager } from "@/lib/session-manager";
  78. import { RateLimitService } from "@/lib/rate-limit";
  79. import { SessionTracker } from "@/lib/session-tracker";
  80. import {
  81. updateMessageRequestCost,
  82. updateMessageRequestDetails,
  83. updateMessageRequestDuration,
  84. } from "@/repository/message";
  85. import { findLatestPriceByModel } from "@/repository/model-price";
  86. import { getSystemSettings } from "@/repository/system-config";
  /**
   * Default per-token prices used by every test unless overridden:
   * 3 per million input tokens and 15 per million output tokens.
   */
  const testPriceData: ModelPriceData = {
    input_cost_per_token: 3e-6,
    output_cost_per_token: 1.5e-5,
  };
  /**
   * Wraps price data in a price-table record of the shape resolved by the
   * findLatestPriceByModel mock.
   *
   * @param modelName - model the record is keyed by
   * @param priceData - per-token prices stored on the record (kept by reference)
   * @returns a record with id 1 and fresh createdAt/updatedAt timestamps
   */
  function makePriceRecord(modelName: string, priceData: ModelPriceData) {
    const record = {
      id: 1,
      modelName: modelName,
      priceData: priceData,
      createdAt: new Date(),
      updatedAt: new Date(),
    };
    return record;
  }
  /**
   * Builds the system-settings value resolved by the getSystemSettings mock.
   *
   * Only billing-model source and stream-buffer fields are set. The cast lets the
   * fixture omit any other settings fields (presumably irrelevant to these tests — verify
   * if the handler starts reading more settings).
   *
   * @param billingModelSource - which model name billing should use; defaults to "original"
   * @returns the settings object, typed as getSystemSettings()'s resolved value
   */
  function makeSystemSettings(
    billingModelSource: "original" | "redirected" = "original"
  ): Awaited<ReturnType<typeof getSystemSettings>> {
    // Awaited<> is the built-in equivalent of the hand-rolled `extends Promise<infer T>`
    // conditional, and still yields the right type if getSystemSettings is ever synchronous.
    return {
      billingModelSource,
      streamBufferEnabled: false,
      streamBufferMode: "none",
      streamBufferSize: 0,
    } as Awaited<ReturnType<typeof getSystemSettings>>;
  }
  /**
   * Builds a ProxySession for the response handler without running its constructor.
   *
   * The instance is created from ProxySession.prototype and populated with plain fields
   * and stubbed accessors. setOriginalModel, setSessionId, setProvider, setAuthState and
   * setMessageContext are replaced on the instance by simple field assignments. They are then
   * used to attach a provider (id 99), user (id 123) and key (id 456), all with a fixed
   * 00:00 daily reset, plus a message context with id `messageId`.
   *
   * @param opts.originalModel - returned by getOriginalModel() and stored as originalModelName
   * @param opts.redirectedModel - model on the request; returned by getCurrentModel()
   * @param opts.sessionId - assigned through setSessionId()
   * @param opts.messageId - id of the message context
   */
  function createSession(opts: {
    originalModel: string;
    redirectedModel: string;
    sessionId: string;
    messageId: number;
  }): ProxySession {
    const { originalModel, redirectedModel, sessionId, messageId } = opts;
    const session = Object.create(ProxySession.prototype) as ProxySession;

    // Field values and stubbed methods, starting from an "empty" session.
    const initialState = {
      request: { message: {}, log: "(test)", model: redirectedModel },
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("http://localhost/v1/messages"),
      headers: new Headers(),
      headerLog: "",
      userAgent: null,
      context: {},
      clientAbortSignal: null,
      userName: "test-user",
      authState: null,
      provider: null,
      messageContext: null,
      sessionId: null,
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      endpointPolicy: resolveEndpointPolicy("/v1/messages"),
      isHeaderModified: () => false,
      getContext1mApplied: () => false,
      getOriginalModel: () => originalModel,
      getCurrentModel: () => redirectedModel,
      getProviderChain: () => [],
      getCachedPriceDataByBillingSource: async () => testPriceData,
      recordTtfb: () => 100,
      ttfbMs: null,
      getRequestSequence: () => 1,
    };
    Object.assign(session, initialState);

    // Override the real setters with plain field writes on this instance.
    Object.assign(session, {
      setOriginalModel(this: { originalModelName: string | null }, m: string | null): void {
        this.originalModelName = m;
      },
      setSessionId(this: { sessionId: string | null }, s: string): void {
        this.sessionId = s;
      },
      setProvider(this: { provider: unknown }, p: unknown): void {
        this.provider = p;
      },
      setAuthState(this: { authState: unknown }, a: unknown): void {
        this.authState = a;
      },
      setMessageContext(this: { messageContext: unknown }, c: unknown): void {
        this.messageContext = c;
      },
    });

    session.setOriginalModel(originalModel);
    session.setSessionId(sessionId);

    // Every fixture entity shares the same fixed daily reset.
    const dailyReset = { dailyResetTime: "00:00", dailyResetMode: "fixed" };
    const provider = {
      id: 99,
      name: "test-provider",
      providerType: "claude",
      costMultiplier: 1.0,
      streamingIdleTimeoutMs: 0,
      ...dailyReset,
    } as unknown;
    const user = { id: 123, name: "test-user", ...dailyReset } as unknown;
    const key = { id: 456, name: "test-key", ...dailyReset } as unknown;

    session.setProvider(provider);
    session.setAuthState({ user, key, apiKey: "sk-test", success: true });
    session.setMessageContext({
      id: messageId,
      createdAt: new Date(),
      user,
      key,
      apiKey: "sk-test",
    });

    return session;
  }
  /**
   * Builds a 200 JSON response whose body is `{"type":"message","usage":…}`.
   *
   * @param usage - token counts reported in the body's `usage` field
   */
  function createNonStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
    const body = JSON.stringify({ type: "message", usage });
    const jsonHeaders = { "content-type": "application/json" };
    return new Response(body, { status: 200, headers: jsonHeaders });
  }
  /**
   * Builds a 200 `text/event-stream` response. The stream carries a single
   * `message_delta` SSE event with the given usage and then closes.
   *
   * @param usage - token counts placed in the event's `usage` field
   */
  function createStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
    const eventBytes = new TextEncoder().encode(
      `event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`
    );
    const body = new ReadableStream<Uint8Array>({
      start(ctrl) {
        ctrl.enqueue(eventBytes);
        ctrl.close();
      },
    });
    const sseHeaders = { "content-type": "text/event-stream" };
    return new Response(body, { status: 200, headers: sseHeaders });
  }
  /**
   * Awaits every promise captured by the mocked AsyncTaskManager.register.
   *
   * Drains in rounds: after each batch settles, any tasks registered in the meantime
   * (e.g. by a task that itself registers follow-up work) are awaited too. The loop stops
   * only when a round ends with an empty queue. Rejections propagate through Promise.all,
   * so a failing background task fails the calling test.
   */
  async function drainAsyncTasks(): Promise<void> {
    do {
      // Take ownership of everything queued so far, leaving the shared array empty.
      const batch = asyncTasks.splice(0, asyncTasks.length);
      await Promise.all(batch);
    } while (asyncTasks.length > 0);
  }
  // Before every test: wipe recorded mock calls and drop any captured background tasks
  // left over from the previous test.
  beforeEach(() => {
    vi.clearAllMocks();
    asyncTasks.length = 0;
  });
  describe("Lease Budget Decrement after trackCostToRedis", () => {
    const originalModel = "claude-sonnet-4-20250514";
    const usage = { input_tokens: 1000, output_tokens: 500 };
    // Cost of `usage` at testPriceData prices:
    // (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105
    const expectedCost = 0.0105;
    // Lease windows each entity is expected to be decremented for.
    const WINDOWS = ["5h", "daily", "weekly", "monthly"] as const;
    // Entity IDs wired up by createSession: key 456, user 123, provider 99.
    const DEFAULT_ENTITIES = [
      { id: 456, type: "key" },
      { id: 123, type: "user" },
      { id: 99, type: "provider" },
    ] as const;
    // 4 windows x 3 entity types.
    const EXPECTED_CALL_COUNT = WINDOWS.length * DEFAULT_ENTITIES.length;

    /** Argument lists of every decrementLeaseBudget call recorded in the current test. */
    function decrementCalls() {
      return vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
    }

    /**
     * Asserts that every (entity, window) pair has exactly one decrementLeaseBudget call.
     * When `cost` is given, also asserts that the call's 4th argument is about `cost`.
     */
    function expectOneDecrementPerEntityAndWindow(
      entities: ReadonlyArray<{ id: number; type: string }>,
      cost?: number
    ): void {
      const calls = decrementCalls();
      for (const entity of entities) {
        for (const leaseWindow of WINDOWS) {
          const matching = calls.filter(
            (call) => call[0] === entity.id && call[1] === entity.type && call[2] === leaseWindow
          );
          expect(matching).toHaveLength(1);
          if (cost !== undefined) {
            expect(matching[0]?.[3]).toBeCloseTo(cost, 4);
          }
        }
      }
    }

    beforeEach(async () => {
      vi.mocked(getSystemSettings).mockResolvedValue(makeSystemSettings("original"));
      vi.mocked(findLatestPriceByModel).mockResolvedValue(
        makePriceRecord(originalModel, testPriceData)
      );
      vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
      vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
      vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
      vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
      vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
      vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({
        success: true,
        newRemaining: 10,
      });
      vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
    });

    it("should call decrementLeaseBudget for all windows and entity types (non-stream)", async () => {
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-1",
        messageId: 5001,
      });
      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();
      expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
      expect(decrementCalls()).toHaveLength(EXPECTED_CALL_COUNT);
      // key(456), user(123), provider(99) x 5h/daily/weekly/monthly, each at the request cost.
      expectOneDecrementPerEntityAndWindow(DEFAULT_ENTITIES, expectedCost);
    });

    it("should call decrementLeaseBudget for all windows and entity types (stream)", async () => {
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-2",
        messageId: 5002,
      });
      const response = createStreamResponse(usage);
      const clientResponse = await ProxyResponseHandler.dispatch(session, response);
      // Consume the client stream so the handler finishes processing it.
      await clientResponse.text();
      await drainAsyncTasks();
      expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
      expect(decrementCalls()).toHaveLength(EXPECTED_CALL_COUNT);
      // Beyond the count, every (entity, window) pair must actually be covered.
      expectOneDecrementPerEntityAndWindow(DEFAULT_ENTITIES);
    });

    it("should NOT call decrementLeaseBudget when cost is zero", async () => {
      // Price data that results in zero cost.
      const zeroPriceData: ModelPriceData = {
        input_cost_per_token: 0,
        output_cost_per_token: 0,
      };
      vi.mocked(findLatestPriceByModel).mockResolvedValue(
        makePriceRecord(originalModel, zeroPriceData)
      );
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-3",
        messageId: 5003,
      });
      // The session's cached price lookup must also return zero prices.
      (
        session as { getCachedPriceDataByBillingSource: () => Promise<ModelPriceData> }
      ).getCachedPriceDataByBillingSource = async () => zeroPriceData;
      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();
      expect(RateLimitService.decrementLeaseBudget).not.toHaveBeenCalled();
    });

    it("should call decrementLeaseBudget exactly once per request (no duplicates)", async () => {
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-4",
        messageId: 5004,
      });
      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();
      // One "<id>-<type>-<window>" key per call; a duplicate would collapse in the Set.
      const callKeys = decrementCalls().map((call) => `${call[0]}-${call[1]}-${call[2]}`);
      expect(new Set(callKeys).size).toBe(callKeys.length);
      expect(callKeys).toHaveLength(EXPECTED_CALL_COUNT);
    });

    it("should use correct entity IDs from session", async () => {
      const customKeyId = 789;
      const customUserId = 321;
      const customProviderId = 111;
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-5",
        messageId: 5005,
      });
      // Replace the default entities with custom IDs.
      const customUser = {
        id: customUserId,
        name: "custom-user",
        dailyResetTime: "00:00",
        dailyResetMode: "fixed",
      };
      const customKey = {
        id: customKeyId,
        name: "custom-key",
        dailyResetTime: "00:00",
        dailyResetMode: "fixed",
      };
      session.setProvider({
        id: customProviderId,
        name: "custom-provider",
        providerType: "claude",
        costMultiplier: 1.0,
        // Same as the createSession provider fixture.
        streamingIdleTimeoutMs: 0,
        dailyResetTime: "00:00",
        dailyResetMode: "fixed",
      } as unknown);
      session.setAuthState({
        user: customUser,
        key: customKey,
        apiKey: "sk-custom",
        success: true,
      });
      session.setMessageContext({
        id: 5005,
        createdAt: new Date(),
        user: customUser,
        key: customKey,
        apiKey: "sk-custom",
      });
      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();

      const calls = decrementCalls();
      const customEntities = [
        { id: customKeyId, type: "key" },
        { id: customUserId, type: "user" },
        { id: customProviderId, type: "provider" },
      ];
      for (const entity of customEntities) {
        const entityCalls = calls.filter((c) => c[1] === entity.type);
        // Every call for this entity type uses the overridden ID, once per window.
        expect(entityCalls.every((c) => c[0] === entity.id)).toBe(true);
        expect(entityCalls).toHaveLength(WINDOWS.length);
      }
    });

    it("should use fire-and-forget pattern (not block on decrement failures)", async () => {
      // Every decrement attempt fails.
      vi.mocked(RateLimitService.decrementLeaseBudget).mockRejectedValue(
        new Error("Redis connection failed")
      );
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-6",
        messageId: 5006,
      });
      const response = createNonStreamResponse(usage);
      // dispatch must still resolve even though decrementLeaseBudget rejects.
      await expect(ProxyResponseHandler.dispatch(session, response)).resolves.toBeDefined();
      // drainAsyncTasks rethrows task rejections, so this also fails if the error escapes
      // the background task.
      await drainAsyncTasks();
      // The decrement was still attempted.
      expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
    });
  });