cost-limits.test.ts 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
  2. const pipelineCommands: Array<unknown[]> = [];
  3. const pipeline = {
  4. zadd: vi.fn((...args: unknown[]) => {
  5. pipelineCommands.push(["zadd", ...args]);
  6. return pipeline;
  7. }),
  8. expire: vi.fn((...args: unknown[]) => {
  9. pipelineCommands.push(["expire", ...args]);
  10. return pipeline;
  11. }),
  12. exec: vi.fn(async () => {
  13. pipelineCommands.push(["exec"]);
  14. return [];
  15. }),
  16. incrbyfloat: vi.fn(() => pipeline),
  17. zremrangebyscore: vi.fn(() => pipeline),
  18. zcard: vi.fn(() => pipeline),
  19. };
  20. const redisClient = {
  21. status: "ready",
  22. eval: vi.fn(async () => "0"),
  23. exists: vi.fn(async () => 1),
  24. get: vi.fn(async () => null),
  25. set: vi.fn(async () => "OK"),
  26. setex: vi.fn(async () => "OK"),
  27. pipeline: vi.fn(() => pipeline),
  28. };
  29. vi.mock("@/lib/redis", () => ({
  30. getRedisClient: () => redisClient,
  31. }));
  32. const resolveSystemTimezoneMock = vi.hoisted(() => vi.fn(async () => "Asia/Shanghai"));
  33. vi.mock("@/lib/utils/timezone", () => ({
  34. resolveSystemTimezone: resolveSystemTimezoneMock,
  35. }));
  36. const statisticsMock = {
  37. // total cost
  38. sumKeyTotalCost: vi.fn(async () => 0),
  39. sumUserTotalCost: vi.fn(async () => 0),
  40. sumProviderTotalCost: vi.fn(async () => 0),
  41. // fixed-window sums
  42. sumKeyCostInTimeRange: vi.fn(async () => 0),
  43. sumProviderCostInTimeRange: vi.fn(async () => 0),
  44. sumUserCostInTimeRange: vi.fn(async () => 0),
  45. // rolling-window entries
  46. findKeyCostEntriesInTimeRange: vi.fn(async () => []),
  47. findProviderCostEntriesInTimeRange: vi.fn(async () => []),
  48. findUserCostEntriesInTimeRange: vi.fn(async () => []),
  49. };
  50. vi.mock("@/repository/statistics", () => statisticsMock);
  51. describe("RateLimitService - cost limits and quota checks", () => {
  52. const nowMs = 1_700_000_000_000;
  53. beforeEach(() => {
  54. pipelineCommands.length = 0;
  55. vi.resetAllMocks();
  56. resolveSystemTimezoneMock.mockResolvedValue("Asia/Shanghai");
  57. vi.useFakeTimers();
  58. vi.setSystemTime(new Date(nowMs));
  59. });
  60. afterEach(() => {
  61. vi.useRealTimers();
  62. });
  63. it("checkCostLimits:未设置任何限额时应直接放行", async () => {
  64. const { RateLimitService } = await import("@/lib/rate-limit");
  65. const result = await RateLimitService.checkCostLimits(1, "key", {
  66. limit_5h_usd: null,
  67. limit_daily_usd: null,
  68. limit_weekly_usd: null,
  69. limit_monthly_usd: null,
  70. });
  71. expect(result).toEqual({ allowed: true });
  72. expect(redisClient.eval).not.toHaveBeenCalled();
  73. expect(redisClient.get).not.toHaveBeenCalled();
  74. });
  75. it("checkCostLimits:Key 每日 fixed 超限时应返回 not allowed", async () => {
  76. const { RateLimitService } = await import("@/lib/rate-limit");
  77. redisClient.get.mockImplementation(async (key: string) => {
  78. if (key === "key:1:cost_daily_0000") return "12";
  79. return "0";
  80. });
  81. const result = await RateLimitService.checkCostLimits(1, "key", {
  82. limit_5h_usd: null,
  83. limit_daily_usd: 10,
  84. daily_reset_mode: "fixed",
  85. daily_reset_time: "00:00",
  86. limit_weekly_usd: null,
  87. limit_monthly_usd: null,
  88. });
  89. expect(result.allowed).toBe(false);
  90. expect(result.reason).toContain("Key 每日消费上限已达到(12.0000/10)");
  91. });
  92. it("checkCostLimits:Provider 每日 rolling 超限时应返回 not allowed", async () => {
  93. const { RateLimitService } = await import("@/lib/rate-limit");
  94. redisClient.eval.mockResolvedValueOnce("11");
  95. const result = await RateLimitService.checkCostLimits(9, "provider", {
  96. limit_5h_usd: null,
  97. limit_daily_usd: 10,
  98. daily_reset_mode: "rolling",
  99. daily_reset_time: "00:00",
  100. limit_weekly_usd: null,
  101. limit_monthly_usd: null,
  102. });
  103. expect(result.allowed).toBe(false);
  104. expect(result.reason).toContain("供应商 每日消费上限已达到(11.0000/10)");
  105. });
  106. it("checkCostLimits:User fast-path 的类型标识应为 User(避免错误标为“供应商”)", async () => {
  107. const { RateLimitService } = await import("@/lib/rate-limit");
  108. redisClient.get.mockImplementation(async (key: string) => {
  109. if (key === "user:1:cost_weekly") return "20";
  110. return "0";
  111. });
  112. const result = await RateLimitService.checkCostLimits(1, "user", {
  113. limit_5h_usd: null,
  114. limit_daily_usd: null,
  115. limit_weekly_usd: 10,
  116. limit_monthly_usd: null,
  117. });
  118. expect(result.allowed).toBe(false);
  119. expect(result.reason).toContain("User 周消费上限已达到(20.0000/10)");
  120. });
  121. it("checkCostLimits:Redis cache miss 时应 fallback 到 DB 查询", async () => {
  122. const { RateLimitService } = await import("@/lib/rate-limit");
  123. redisClient.get.mockResolvedValueOnce(null);
  124. statisticsMock.sumKeyCostInTimeRange.mockResolvedValueOnce(20);
  125. const result = await RateLimitService.checkCostLimits(1, "key", {
  126. limit_5h_usd: null,
  127. limit_daily_usd: 10,
  128. daily_reset_mode: "fixed",
  129. daily_reset_time: "00:00",
  130. limit_weekly_usd: null,
  131. limit_monthly_usd: null,
  132. });
  133. expect(result.allowed).toBe(false);
  134. expect(statisticsMock.sumKeyCostInTimeRange).toHaveBeenCalledTimes(1);
  135. expect(redisClient.set).toHaveBeenCalled();
  136. });
  137. it("checkTotalCostLimit:limitTotalUsd 未设置时应放行", async () => {
  138. const { RateLimitService } = await import("@/lib/rate-limit");
  139. expect(await RateLimitService.checkTotalCostLimit(1, "user", null)).toEqual({ allowed: true });
  140. expect(await RateLimitService.checkTotalCostLimit(1, "user", undefined as any)).toEqual({
  141. allowed: true,
  142. });
  143. expect(await RateLimitService.checkTotalCostLimit(1, "user", 0)).toEqual({ allowed: true });
  144. });
  145. it("checkTotalCostLimit:Key 缺失 keyHash 时应跳过 enforcement", async () => {
  146. const { RateLimitService } = await import("@/lib/rate-limit");
  147. const result = await RateLimitService.checkTotalCostLimit(1, "key", 10, undefined);
  148. expect(result).toEqual({ allowed: true });
  149. });
  150. it("checkTotalCostLimit:Redis cache hit 且已超限时应返回 not allowed", async () => {
  151. const { RateLimitService } = await import("@/lib/rate-limit");
  152. redisClient.get.mockImplementation(async (key: string) => {
  153. if (key === "total_cost:user:7") return "20";
  154. return null;
  155. });
  156. const result = await RateLimitService.checkTotalCostLimit(7, "user", 10);
  157. expect(result.allowed).toBe(false);
  158. expect(result.current).toBe(20);
  159. });
  160. it("checkTotalCostLimit:Redis miss 时应 fallback DB 并写回缓存", async () => {
  161. const { RateLimitService } = await import("@/lib/rate-limit");
  162. redisClient.get.mockResolvedValueOnce(null);
  163. statisticsMock.sumUserTotalCost.mockResolvedValueOnce(5);
  164. const result = await RateLimitService.checkTotalCostLimit(7, "user", 10);
  165. expect(result.allowed).toBe(true);
  166. expect(result.current).toBe(5);
  167. expect(redisClient.setex).toHaveBeenCalledWith("total_cost:user:7", 300, "5");
  168. });
  169. it("checkTotalCostLimit:Provider Redis miss 时应 fallback DB 并写回缓存(cache key 应包含 resetAt)", async () => {
  170. const { RateLimitService } = await import("@/lib/rate-limit");
  171. const resetAt = new Date(nowMs - 123_000);
  172. redisClient.get.mockResolvedValueOnce(null);
  173. statisticsMock.sumProviderTotalCost.mockResolvedValueOnce(5);
  174. const result = await RateLimitService.checkTotalCostLimit(9, "provider", 10, {
  175. resetAt,
  176. });
  177. expect(result.allowed).toBe(true);
  178. expect(result.current).toBe(5);
  179. expect(statisticsMock.sumProviderTotalCost).toHaveBeenCalledTimes(1);
  180. expect(statisticsMock.sumProviderTotalCost).toHaveBeenCalledWith(9, resetAt);
  181. expect(redisClient.setex).toHaveBeenCalledWith(
  182. `total_cost:provider:9:${resetAt.getTime()}`,
  183. 300,
  184. "5"
  185. );
  186. });
  187. it("checkTotalCostLimit:Provider resetAt 为空时应使用 none key 并回退到 DB", async () => {
  188. const { RateLimitService } = await import("@/lib/rate-limit");
  189. redisClient.get.mockResolvedValueOnce(null);
  190. statisticsMock.sumProviderTotalCost.mockResolvedValueOnce(5);
  191. const result = await RateLimitService.checkTotalCostLimit(9, "provider", 10, {
  192. resetAt: null,
  193. });
  194. expect(result.allowed).toBe(true);
  195. expect(result.current).toBe(5);
  196. expect(statisticsMock.sumProviderTotalCost).toHaveBeenCalledWith(9, null);
  197. expect(redisClient.setex).toHaveBeenCalledWith("total_cost:provider:9:none", 300, "5");
  198. });
  199. it("checkTotalCostLimit:Provider Redis cache hit 且已超限时应返回 not allowed(按 resetAt key 命中)", async () => {
  200. const { RateLimitService } = await import("@/lib/rate-limit");
  201. const resetAt = new Date(nowMs - 456_000);
  202. redisClient.get.mockImplementation(async (key: string) => {
  203. if (key === `total_cost:provider:9:${resetAt.getTime()}`) return "20";
  204. return null;
  205. });
  206. const result = await RateLimitService.checkTotalCostLimit(9, "provider", 10, {
  207. resetAt,
  208. });
  209. expect(result.allowed).toBe(false);
  210. expect(result.current).toBe(20);
  211. });
  212. it("checkUserDailyCost:fixed 模式 cache hit 超限时应拦截", async () => {
  213. const { RateLimitService } = await import("@/lib/rate-limit");
  214. redisClient.get.mockImplementation(async (key: string) => {
  215. if (key === "user:1:cost_daily_0000") return "20";
  216. return null;
  217. });
  218. const result = await RateLimitService.checkUserDailyCost(1, 10, "00:00", "fixed");
  219. expect(result.allowed).toBe(false);
  220. expect(result.current).toBe(20);
  221. });
  222. it("checkUserDailyCost:fixed 模式 cache miss 时应 fallback DB 并写回缓存", async () => {
  223. const { RateLimitService } = await import("@/lib/rate-limit");
  224. redisClient.get.mockResolvedValueOnce(null);
  225. statisticsMock.sumUserCostInTimeRange.mockResolvedValueOnce(12);
  226. const result = await RateLimitService.checkUserDailyCost(1, 10, "00:00", "fixed");
  227. expect(result.allowed).toBe(false);
  228. expect(result.current).toBe(12);
  229. expect(redisClient.set).toHaveBeenCalled();
  230. });
  231. it("checkUserDailyCost:rolling 模式 cache miss 时应走明细查询并 warm ZSET", async () => {
  232. const { RateLimitService } = await import("@/lib/rate-limit");
  233. redisClient.eval.mockResolvedValueOnce("0");
  234. redisClient.exists.mockResolvedValueOnce(0);
  235. statisticsMock.findUserCostEntriesInTimeRange.mockResolvedValueOnce([
  236. { id: 101, createdAt: new Date(nowMs - 60_000), costUsd: 3 },
  237. { id: 102, createdAt: new Date(nowMs - 30_000), costUsd: 8 },
  238. ]);
  239. const result = await RateLimitService.checkUserDailyCost(1, 10, "00:00", "rolling");
  240. expect(result.allowed).toBe(false);
  241. expect(result.current).toBe(11);
  242. const zaddCalls = pipelineCommands.filter((c) => c[0] === "zadd");
  243. expect(zaddCalls).toHaveLength(2);
  244. expect(pipelineCommands.some((c) => c[0] === "expire")).toBe(true);
  245. });
  246. });