// rolling-window-cache-warm.test.ts
  1. import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
  2. const pipelineCommands: Array<unknown[]> = [];
  3. const pipeline = {
  4. zadd: vi.fn((...args: unknown[]) => {
  5. pipelineCommands.push(["zadd", ...args]);
  6. return pipeline;
  7. }),
  8. expire: vi.fn((...args: unknown[]) => {
  9. pipelineCommands.push(["expire", ...args]);
  10. return pipeline;
  11. }),
  12. incrbyfloat: vi.fn(() => pipeline),
  13. exec: vi.fn(async () => {
  14. pipelineCommands.push(["exec"]);
  15. return [];
  16. }),
  17. };
  18. const redisClient = {
  19. status: "ready",
  20. eval: vi.fn(async () => "0"),
  21. exists: vi.fn(async () => 0),
  22. pipeline: vi.fn(() => pipeline),
  23. get: vi.fn(async () => null),
  24. set: vi.fn(async () => "OK"),
  25. };
  26. vi.mock("@/lib/redis", () => ({
  27. getRedisClient: () => redisClient,
  28. }));
  29. vi.mock("@/lib/utils/timezone", () => ({
  30. resolveSystemTimezone: vi.fn(async () => "Asia/Shanghai"),
  31. }));
  32. const statisticsMock = {
  33. sumKeyTotalCost: vi.fn(async () => 0),
  34. sumUserCostToday: vi.fn(async () => 0),
  35. sumUserTotalCost: vi.fn(async () => 0),
  36. sumKeyCostInTimeRange: vi.fn(async () => 0),
  37. sumProviderCostInTimeRange: vi.fn(async () => 0),
  38. sumUserCostInTimeRange: vi.fn(async () => 0),
  39. findKeyCostEntriesInTimeRange: vi.fn(async () => []),
  40. findProviderCostEntriesInTimeRange: vi.fn(async () => []),
  41. findUserCostEntriesInTimeRange: vi.fn(async () => []),
  42. };
  43. vi.mock("@/repository/statistics", () => statisticsMock);
  44. describe("RateLimitService rolling window cache warm", () => {
  45. const nowMs = 1_700_000_000_000;
  46. beforeEach(() => {
  47. pipelineCommands.length = 0;
  48. vi.clearAllMocks();
  49. vi.useFakeTimers();
  50. vi.setSystemTime(new Date(nowMs));
  51. });
  52. afterEach(() => {
  53. vi.useRealTimers();
  54. });
  55. it("getCurrentCost(5h) rebuilds ZSET from DB entries on cache miss", async () => {
  56. statisticsMock.findKeyCostEntriesInTimeRange.mockResolvedValueOnce([
  57. { id: 101, createdAt: new Date(nowMs - 4 * 60 * 60 * 1000), costUsd: 1.5 },
  58. { id: 102, createdAt: new Date(nowMs - 1 * 60 * 60 * 1000), costUsd: 2.0 },
  59. ]);
  60. const { RateLimitService } = await import("@/lib/rate-limit");
  61. const current = await RateLimitService.getCurrentCost(1, "key", "5h");
  62. expect(current).toBeCloseTo(3.5, 10);
  63. const zaddCalls = pipelineCommands.filter((c) => c[0] === "zadd");
  64. expect(zaddCalls).toHaveLength(2);
  65. const expireCalls = pipelineCommands.filter((c) => c[0] === "expire");
  66. expect(expireCalls).toHaveLength(1);
  67. expect(expireCalls[0][1]).toBe("key:1:cost_5h_rolling");
  68. expect(expireCalls[0][2]).toBe(21600);
  69. // member format: `${createdAtMs}:${requestId}:${costUsd}`
  70. const first = zaddCalls[0];
  71. expect(first[1]).toBe("key:1:cost_5h_rolling");
  72. expect(first[2]).toBe(nowMs - 4 * 60 * 60 * 1000);
  73. expect(first[3]).toBe(`${nowMs - 4 * 60 * 60 * 1000}:101:1.5`);
  74. });
  75. it("trackCost passes requestId and uses createdAtMs for rolling windows", async () => {
  76. const { RateLimitService } = await import("@/lib/rate-limit");
  77. await RateLimitService.trackCost(1, 2, "sess", 0.5, {
  78. requestId: 123,
  79. createdAtMs: nowMs - 1000,
  80. keyResetMode: "fixed",
  81. providerResetMode: "fixed",
  82. });
  83. const evalCalls = redisClient.eval.mock.calls;
  84. expect(evalCalls.length).toBeGreaterThanOrEqual(2);
  85. const [firstCall] = evalCalls;
  86. expect(firstCall[2]).toBe("key:1:cost_5h_rolling");
  87. expect(firstCall[4]).toBe(String(nowMs - 1000));
  88. expect(firstCall[6]).toBe("123");
  89. });
  90. });