response-handler-lease-decrement.test.ts 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  /**
   * TDD: RED Phase - Tests for lease budget decrement in response-handler.ts
   *
   * Verifies that decrementLeaseBudget is called correctly after trackCostToRedis completes:
   * - All windows: 5h, daily, weekly, monthly
   * - All entity types: key, user, provider
   * - Entity IDs are taken from the session's key, user, and provider
   * - Zero-cost requests do NOT trigger a decrement
   * - Each window/entity pair is decremented exactly once per request (no duplicates)
   * - A failing decrement does not break response handling (fire-and-forget)
   */
  10. import { beforeEach, describe, expect, it, vi } from "vitest";
  11. import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
  12. import type { ModelPriceData } from "@/types/model-price";
  // Background task promises captured from the mocked AsyncTaskManager, so tests can
  // await them (see drainAsyncTasks). Vitest hoists vi.mock factories above ordinary
  // top-level declarations; creating the array with vi.hoisted hoists it alongside
  // them, instead of relying on `register` only being invoked after this module's
  // top-level code has run.
  const { asyncTasks } = vi.hoisted(() => ({
    asyncTasks: [] as Promise<void>[],
  }));
  vi.mock("@/lib/async-task-manager", () => ({
    AsyncTaskManager: {
      // Record the task's promise instead of handing it to a real manager.
      register: (_taskId: string, promise: Promise<void>) => {
        asyncTasks.push(promise);
        return new AbortController();
      },
      cleanup: () => {},
      cancel: () => {},
    },
  }));
  // Logging is silenced: every level is a no-op.
  vi.mock("@/lib/logger", () => {
    const noop = () => {};
    return {
      logger: { debug: noop, info: noop, warn: noop, error: noop, trace: noop },
    };
  });
  // Cloud price-table sync is replaced with a no-op.
  vi.mock("@/lib/price-sync/cloud-price-updater", () => {
    return { requestCloudPriceTableSync: () => {} };
  });
  // Repository functions are bare spies; tests configure their results.
  vi.mock("@/repository/model-price", () => {
    return { findLatestPriceByModel: vi.fn() };
  });
  vi.mock("@/repository/system-config", () => {
    return { getSystemSettings: vi.fn() };
  });
  vi.mock("@/repository/message", () => {
    return {
      updateMessageRequestCost: vi.fn(),
      updateMessageRequestDetails: vi.fn(),
      updateMessageRequestDuration: vi.fn(),
    };
  });
  // Session bookkeeping spies.
  vi.mock("@/lib/session-manager", () => {
    const SessionManager = {
      updateSessionUsage: vi.fn(),
      storeSessionResponse: vi.fn(),
      extractCodexPromptCacheKey: vi.fn(),
      updateSessionWithCodexCacheKey: vi.fn(),
    };
    return { SessionManager };
  });
  // Rate-limit spies; decrementLeaseBudget is the function under observation.
  vi.mock("@/lib/rate-limit", () => {
    const RateLimitService = {
      trackCost: vi.fn(),
      trackUserDailyCost: vi.fn(),
      decrementLeaseBudget: vi.fn(),
    };
    return { RateLimitService };
  });
  vi.mock("@/lib/session-tracker", () => {
    return { SessionTracker: { refreshSession: vi.fn() } };
  });
  // Each getInstance() call returns a fresh tracker whose endRequest does nothing.
  vi.mock("@/lib/proxy-status-tracker", () => {
    const getInstance = () => {
      return { endRequest: () => {} };
    };
    return { ProxyStatusTracker: { getInstance } };
  });
  75. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  76. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  77. import { SessionManager } from "@/lib/session-manager";
  78. import { RateLimitService } from "@/lib/rate-limit";
  79. import { SessionTracker } from "@/lib/session-tracker";
  80. import {
  81. updateMessageRequestCost,
  82. updateMessageRequestDetails,
  83. updateMessageRequestDuration,
  84. } from "@/repository/message";
  85. import { findLatestPriceByModel } from "@/repository/model-price";
  86. import { getSystemSettings } from "@/repository/system-config";
  /**
   * Default per-token prices for these tests: 3 per million input tokens and
   * 15 per million output tokens.
   */
  const testPriceData: ModelPriceData = {
    input_cost_per_token: 3e-6,
    output_cost_per_token: 1.5e-5,
  };
  /**
   * Builds a price row for the mocked `findLatestPriceByModel`.
   *
   * @param modelName - model name stored on the row
   * @param priceData - pricing payload stored on the row (kept by reference)
   * @returns a record with a fixed id of 1 and fresh createdAt/updatedAt dates
   */
  function makePriceRecord(modelName: string, priceData: ModelPriceData) {
    const record = {
      id: 1,
      modelName,
      priceData,
      createdAt: new Date(),
      updatedAt: new Date(),
    };
    return record;
  }
  /**
   * Builds the minimal system settings returned by the mocked `getSystemSettings`.
   *
   * Only the billing model source and the stream-buffer fields are populated.
   * The type assertion lets this partial object stand in for the full settings
   * type that `getSystemSettings` resolves to.
   *
   * @param billingModelSource - which model name billing should use; defaults to "original"
   */
  function makeSystemSettings(billingModelSource: "original" | "redirected" = "original") {
    return {
      billingModelSource,
      streamBufferEnabled: false,
      streamBufferMode: "none",
      streamBufferSize: 0,
    } as Awaited<ReturnType<typeof getSystemSettings>>;
  }
  /**
   * Builds a ProxySession for these tests without running its constructor.
   *
   * The instance is created from ProxySession.prototype and given only the state
   * and method stubs listed below. The setters used here are replaced by own-property
   * stubs that just assign the matching field, so the fixture fully controls the
   * resulting session.
   *
   * Fixture identities: provider id 99, user id 123, key id 456.
   */
  function createSession(opts: {
    originalModel: string;
    redirectedModel: string;
    sessionId: string;
    messageId: number;
  }): ProxySession {
    const { originalModel, redirectedModel, sessionId, messageId } = opts;

    // Entity fixtures, shared between the auth state and the message context.
    const provider = {
      id: 99,
      name: "test-provider",
      providerType: "claude",
      costMultiplier: 1.0,
      streamingIdleTimeoutMs: 0,
      dailyResetTime: "00:00",
      dailyResetMode: "fixed",
    } as unknown;
    const user = {
      id: 123,
      name: "test-user",
      dailyResetTime: "00:00",
      dailyResetMode: "fixed",
    } as unknown;
    const key = {
      id: 456,
      name: "test-key",
      dailyResetTime: "00:00",
      dailyResetMode: "fixed",
    } as unknown;

    const session = Object.create(ProxySession.prototype) as ProxySession;

    // Plain state fields.
    const state = {
      request: { message: {}, log: "(test)", model: redirectedModel },
      startTime: Date.now(),
      method: "POST",
      requestUrl: new URL("http://localhost/v1/messages"),
      headers: new Headers(),
      headerLog: "",
      userAgent: null,
      context: {},
      clientAbortSignal: null,
      userName: "test-user",
      authState: null,
      provider: null,
      messageContext: null,
      sessionId: null,
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      resolvedPricingCache: new Map(),
      endpointPolicy: resolveEndpointPolicy("/v1/messages"),
      ttfbMs: null,
    };

    // Read-only accessors that return fixed fixture values.
    const methodStubs = {
      isHeaderModified: () => false,
      getContext1mApplied: () => false,
      getOriginalModel: () => originalModel,
      getCurrentModel: () => redirectedModel,
      getProviderChain: () => [],
      getResolvedPricingByBillingSource: async () => ({
        resolvedModelName: redirectedModel,
        resolvedPricingProviderKey: "test-provider",
        source: "cloud_exact" as const,
        priceData: testPriceData,
      }),
      recordTtfb: () => 100,
      getRequestSequence: () => 1,
    };

    // Setter stubs: own properties that shadow the prototype methods and only
    // write the corresponding field on whichever object they are called on.
    const setterStubs: Record<string, (this: Record<string, unknown>, value: unknown) => void> = {
      setOriginalModel(value) {
        this.originalModelName = value;
      },
      setSessionId(value) {
        this.sessionId = value;
      },
      setProvider(value) {
        this.provider = value;
      },
      setAuthState(value) {
        this.authState = value;
      },
      setMessageContext(value) {
        this.messageContext = value;
      },
    };

    Object.assign(session, state, methodStubs, setterStubs);

    session.setOriginalModel(originalModel);
    session.setSessionId(sessionId);
    session.setProvider(provider);
    session.setAuthState({ user, key, apiKey: "sk-test", success: true });
    session.setMessageContext({
      id: messageId,
      createdAt: new Date(),
      user,
      key,
      apiKey: "sk-test",
    });

    return session;
  }
  /**
   * Builds a non-streaming HTTP 200 JSON response whose body is a
   * `{ type: "message", usage }` object.
   */
  function createNonStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
    const body = JSON.stringify({ type: "message", usage });
    const init = {
      status: 200,
      headers: { "content-type": "application/json" },
    };
    return new Response(body, init);
  }
  /**
   * Builds an HTTP 200 `text/event-stream` response containing a single
   * `message_delta` SSE event that carries `usage`. The stream is closed right
   * after that one chunk is enqueued.
   */
  function createStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
    const payload = JSON.stringify({ usage });
    const chunk = new TextEncoder().encode(`event: message_delta\ndata: ${payload}\n\n`);
    const body = new ReadableStream<Uint8Array>({
      start(controller) {
        controller.enqueue(chunk);
        controller.close();
      },
    });
    return new Response(body, {
      status: 200,
      headers: { "content-type": "text/event-stream" },
    });
  }
  /**
   * Awaits every background task captured by the mocked AsyncTaskManager.
   *
   * Draining repeats until the queue stays empty, so tasks registered while earlier
   * tasks are still running are awaited too (a single snapshot would miss them).
   * If any task rejects, the returned promise rejects, which surfaces errors that
   * escaped the handler's background work.
   */
  async function drainAsyncTasks(): Promise<void> {
    while (asyncTasks.length > 0) {
      const batch = asyncTasks.splice(0, asyncTasks.length);
      await Promise.all(batch);
    }
  }
  // Before every test: forget recorded mock calls and discard any task promises
  // left over from a previous test (truncating the shared array in place).
  beforeEach(() => {
    vi.clearAllMocks();
    asyncTasks.length = 0;
  });
  /**
   * Lease budget decrement behaviour of ProxyResponseHandler.dispatch.
   *
   * With the default fixture (1000 input / 500 output tokens priced by testPriceData),
   * one request should decrement the lease budget once for every combination of
   * window (5h, daily, weekly, monthly) and entity (key, user, provider): 12 calls.
   */
  describe("Lease Budget Decrement after trackCostToRedis", () => {
    const originalModel = "claude-sonnet-4-20250514";
    const usage = { input_tokens: 1000, output_tokens: 500 };

    beforeEach(async () => {
      vi.mocked(getSystemSettings).mockResolvedValue(makeSystemSettings("original"));
      vi.mocked(findLatestPriceByModel).mockResolvedValue(
        makePriceRecord(originalModel, testPriceData)
      );
      vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
      vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
      vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
      vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
      vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
      // Re-established for every test because individual tests override it
      // (e.g. the fire-and-forget test makes it reject).
      vi.mocked(RateLimitService.decrementLeaseBudget).mockResolvedValue({
        success: true,
        newRemaining: 10,
      });
      vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
    });

    it("should call decrementLeaseBudget for all windows and entity types (non-stream)", async () => {
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-1",
        messageId: 5001,
      });
      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();

      // Expected cost: (1000 * 0.000003) + (500 * 0.000015) = 0.003 + 0.0075 = 0.0105
      const expectedCost = 0.0105;
      const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
      // 4 windows x 3 entity types = 12 calls
      expect(calls).toHaveLength(12);

      const windows = ["5h", "daily", "weekly", "monthly"];
      const entities = [
        { id: 456, type: "key" },
        { id: 123, type: "user" },
        { id: 99, type: "provider" },
      ];
      for (const entity of entities) {
        for (const window of windows) {
          const matchingCall = calls.find(
            (call) => call[0] === entity.id && call[1] === entity.type && call[2] === window
          );
          expect(
            matchingCall,
            `missing decrement for ${entity.type} ${entity.id} (${window})`
          ).toBeDefined();
          // Cost should be approximately 0.0105
          expect(matchingCall?.[3]).toBeCloseTo(expectedCost, 4);
        }
      }
    });

    it("should call decrementLeaseBudget for all windows and entity types (stream)", async () => {
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-2",
        messageId: 5002,
      });
      const response = createStreamResponse(usage);
      const clientResponse = await ProxyResponseHandler.dispatch(session, response);
      // Consume the client stream so the handler's stream processing can finish.
      await clientResponse.text();
      await drainAsyncTasks();

      const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
      // Exactly 12 calls (4 windows x 3 entity types)
      expect(calls).toHaveLength(12);
    });

    it("should NOT call decrementLeaseBudget when cost is zero", async () => {
      const zeroPriceData: ModelPriceData = {
        input_cost_per_token: 0,
        output_cost_per_token: 0,
      };
      vi.mocked(findLatestPriceByModel).mockResolvedValue(
        makePriceRecord(originalModel, zeroPriceData)
      );
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-3",
        messageId: 5003,
      });
      // The session's own pricing resolver must also return zero prices.
      Object.assign(session, {
        getResolvedPricingByBillingSource: async () => ({
          resolvedModelName: originalModel,
          resolvedPricingProviderKey: "test-provider",
          source: "cloud_exact" as const,
          priceData: zeroPriceData,
        }),
      });

      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();

      expect(RateLimitService.decrementLeaseBudget).not.toHaveBeenCalled();
    });

    // "Exactly once per request" means once per (entity, type, window) combination.
    it("should call decrementLeaseBudget exactly once per request (no duplicates)", async () => {
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-4",
        messageId: 5004,
      });
      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();

      const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
      const uniqueKeys = new Set(calls.map((call) => `${call[0]}-${call[1]}-${call[2]}`));
      // No duplicates: every call targets a distinct entity/window combination.
      expect(uniqueKeys.size).toBe(calls.length);
      expect(calls).toHaveLength(12); // 4 windows x 3 entities
    });

    it("should use correct entity IDs from session", async () => {
      const customKeyId = 789;
      const customUserId = 321;
      const customProviderId = 111;
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-5",
        messageId: 5005,
      });

      // Replace the default fixtures with custom IDs.
      const customUser = {
        id: customUserId,
        name: "custom-user",
        dailyResetTime: "00:00",
        dailyResetMode: "fixed",
      };
      const customKey = {
        id: customKeyId,
        name: "custom-key",
        dailyResetTime: "00:00",
        dailyResetMode: "fixed",
      };
      session.setProvider({
        id: customProviderId,
        name: "custom-provider",
        providerType: "claude",
        costMultiplier: 1.0,
        dailyResetTime: "00:00",
        dailyResetMode: "fixed",
      } as unknown);
      session.setAuthState({
        user: customUser,
        key: customKey,
        apiKey: "sk-custom",
        success: true,
      });
      session.setMessageContext({
        id: 5005,
        createdAt: new Date(),
        user: customUser,
        key: customKey,
        apiKey: "sk-custom",
      });

      const response = createNonStreamResponse(usage);
      await ProxyResponseHandler.dispatch(session, response);
      await drainAsyncTasks();

      const calls = vi.mocked(RateLimitService.decrementLeaseBudget).mock.calls;
      const idsFor = (type: string) => calls.filter((c) => c[1] === type).map((c) => c[0]);
      // Each entity type gets one call per window (4), all with the custom ID.
      expect(idsFor("key")).toEqual(Array(4).fill(customKeyId));
      expect(idsFor("user")).toEqual(Array(4).fill(customUserId));
      expect(idsFor("provider")).toEqual(Array(4).fill(customProviderId));
    });

    it("should use fire-and-forget pattern (not block on decrement failures)", async () => {
      vi.mocked(RateLimitService.decrementLeaseBudget).mockRejectedValue(
        new Error("Redis connection failed")
      );
      const session = createSession({
        originalModel,
        redirectedModel: originalModel,
        sessionId: "sess-lease-test-6",
        messageId: 5006,
      });
      const response = createNonStreamResponse(usage);

      // dispatch must still resolve even though every decrement rejects...
      await expect(ProxyResponseHandler.dispatch(session, response)).resolves.toBeDefined();
      // ...and draining rejects if a decrement failure escaped the background task.
      await drainAsyncTasks();

      expect(RateLimitService.decrementLeaseBudget).toHaveBeenCalled();
    });
  });