service-extra.test.ts 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442
  1. import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
  // Fake Redis client handed out by the mocked getRedisClient(); rebuilt per test in beforeEach.
  let redisClientRef: any;
  // Chronological log of every command issued on any fake pipeline, as [command, ...args].
  const pipelineCalls: unknown[][] = [];
  /**
   * Builds a fresh fake Redis pipeline.
   *
   * Every chainable command appends `[name, ...args]` to `pipelineCalls` and returns the
   * pipeline itself, so calls can be chained. `exec` logs `["exec"]` and resolves to an
   * empty result list. All members are `vi.fn` spies, so tests can override them
   * (e.g. `pipeline.exec.mockResolvedValueOnce(...)`) or assert on them.
   */
  const makePipeline = () => {
    const record = (command: string, args: unknown[]): void => {
      pipelineCalls.push([command, ...args]);
    };
    const pipeline = {
      eval: vi.fn((...args: unknown[]) => {
        record("eval", args);
        return pipeline;
      }),
      get: vi.fn((...args: unknown[]) => {
        record("get", args);
        return pipeline;
      }),
      incrbyfloat: vi.fn((...args: unknown[]) => {
        record("incrbyfloat", args);
        return pipeline;
      }),
      expire: vi.fn((...args: unknown[]) => {
        record("expire", args);
        return pipeline;
      }),
      zremrangebyscore: vi.fn((...args: unknown[]) => {
        record("zremrangebyscore", args);
        return pipeline;
      }),
      zcard: vi.fn((...args: unknown[]) => {
        record("zcard", args);
        return pipeline;
      }),
      zadd: vi.fn((...args: unknown[]) => {
        record("zadd", args);
        return pipeline;
      }),
      exec: vi.fn(async () => {
        record("exec", []);
        return [];
      }),
    };
    return pipeline;
  };
  // Replace @/lib/logger with inert spies: every level is a bare vi.fn(), so nothing is printed.
  vi.mock("@/lib/logger", () => {
    const logger = {
      debug: vi.fn(),
      info: vi.fn(),
      warn: vi.fn(),
      error: vi.fn(),
    };
    return { logger };
  });
  // getRedisClient() reads the redisClientRef binding lazily on each call, so whatever fake
  // client the current test installed (see beforeEach) is the one the service receives.
  vi.mock("@/lib/redis", () => {
    const getRedisClient = () => redisClientRef;
    return { getRedisClient };
  });
  // Created via vi.hoisted so it already exists when the hoisted vi.mock factory below runs.
  // Resolves to a fixed zone ("Asia/Shanghai") so the service never sees the host's timezone.
  const resolveSystemTimezoneMock = vi.hoisted(() => {
    const timezoneSpy = vi.fn(async () => "Asia/Shanghai");
    return timezoneSpy;
  });
  vi.mock("@/lib/utils/timezone", () => {
    return { resolveSystemTimezone: resolveSystemTimezoneMock };
  });
  /**
   * Stand-in for "@/repository/statistics".
   *
   * Created through vi.hoisted (like resolveSystemTimezoneMock) because vi.mock calls are
   * hoisted above the imports. A plain module-level const only worked here because the
   * service under test is imported dynamically inside each test. A static import would make
   * the factory read the const before initialization.
   */
  const statisticsMock = vi.hoisted(() => ({
    // service.ts imports these statically at its top level, so the exports must exist.
    sumKeyTotalCost: vi.fn(async () => 0),
    sumUserTotalCost: vi.fn(async () => 0),
    sumUserCostInTimeRange: vi.fn(async () => 0),
    // getCurrentCost / checkCostLimitsFromDatabase destructure these from a dynamic import.
    findKeyCostEntriesInTimeRange: vi.fn(async () => []),
    findProviderCostEntriesInTimeRange: vi.fn(async () => []),
    findUserCostEntriesInTimeRange: vi.fn(async () => []),
    sumKeyCostInTimeRange: vi.fn(async () => 0),
    sumProviderCostInTimeRange: vi.fn(async () => 0),
  }));
  vi.mock("@/repository/statistics", () => statisticsMock);
  // Stand-in for SessionTracker's concurrent-session counters (default: 0 sessions).
  // Created through vi.hoisted so the hoisted vi.mock factory below can safely reference it,
  // regardless of whether the service under test is imported statically or dynamically.
  const sessionTrackerMock = vi.hoisted(() => ({
    getKeySessionCount: vi.fn(async () => 0),
    getProviderSessionCount: vi.fn(async () => 0),
    getUserSessionCount: vi.fn(async () => 0),
  }));
  vi.mock("@/lib/session-tracker", () => ({
    SessionTracker: sessionTrackerMock,
  }));
  describe("RateLimitService - other quota paths", () => {
    // Fixed wall-clock instant (ms since epoch) installed with fake timers in beforeEach.
    const nowMs = 1_700_000_000_000;

    /**
     * Re-arms the default resolved values of the module-level mocks.
     *
     * `vi.resetAllMocks()` wipes every mock implementation. Before Vitest 3.0 a reset mock
     * returns `undefined` (not a Promise); Vitest 3 restores the original implementation
     * instead. Re-arming only the timezone mock would let the statistics / session-tracker
     * mocks drift from their declared defaults depending on the Vitest version. Per-test
     * `mockResolvedValueOnce` values still take precedence over these defaults.
     */
    const restoreDefaultMockValues = () => {
      resolveSystemTimezoneMock.mockResolvedValue("Asia/Shanghai");
      statisticsMock.sumKeyTotalCost.mockResolvedValue(0);
      statisticsMock.sumUserTotalCost.mockResolvedValue(0);
      statisticsMock.sumUserCostInTimeRange.mockResolvedValue(0);
      statisticsMock.findKeyCostEntriesInTimeRange.mockResolvedValue([]);
      statisticsMock.findProviderCostEntriesInTimeRange.mockResolvedValue([]);
      statisticsMock.findUserCostEntriesInTimeRange.mockResolvedValue([]);
      statisticsMock.sumKeyCostInTimeRange.mockResolvedValue(0);
      statisticsMock.sumProviderCostInTimeRange.mockResolvedValue(0);
      sessionTrackerMock.getKeySessionCount.mockResolvedValue(0);
      sessionTrackerMock.getProviderSessionCount.mockResolvedValue(0);
      sessionTrackerMock.getUserSessionCount.mockResolvedValue(0);
    };

    beforeEach(() => {
      vi.resetAllMocks();
      restoreDefaultMockValues();
      pipelineCalls.length = 0;
      vi.useFakeTimers();
      vi.setSystemTime(new Date(nowMs));
      // Fresh fake Redis client per test; individual tests override single methods as needed.
      redisClientRef = {
        status: "ready",
        eval: vi.fn(async () => "0"),
        exists: vi.fn(async () => 1),
        get: vi.fn(async () => null),
        set: vi.fn(async () => "OK"),
        setex: vi.fn(async () => "OK"),
        pipeline: vi.fn(() => makePipeline()),
      };
    });

    afterEach(() => {
      vi.useRealTimers();
    });

    it("checkSessionLimit:limit<=0 时应放行", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await expect(RateLimitService.checkSessionLimit(1, "key", 0)).resolves.toEqual({
        allowed: true,
      });
    });

    it("checkSessionLimit:Key 并发数达到上限时应拦截", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      sessionTrackerMock.getKeySessionCount.mockResolvedValueOnce(2);
      const result = await RateLimitService.checkSessionLimit(1, "key", 2);
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("Key并发 Session 上限已达到(2/2)");
    });

    it("checkSessionLimit:Provider 并发数未达上限时应放行", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      sessionTrackerMock.getProviderSessionCount.mockResolvedValueOnce(1);
      await expect(RateLimitService.checkSessionLimit(9, "provider", 2)).resolves.toEqual({
        allowed: true,
      });
    });

    it("checkAndTrackProviderSession:limit<=0 时应放行且不追踪", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 0);
      expect(result).toEqual({ allowed: true, count: 0, tracked: false });
    });

    it("checkAndTrackProviderSession:Redis 非 ready 时应 Fail Open", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.status = "end";
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(result).toEqual({ allowed: true, count: 0, tracked: false });
    });

    it("checkAndTrackProviderSession:达到上限时应返回 not allowed", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([0, 2, 0]);
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("供应商并发 Session 上限已达到(2/2)");
    });

    it("checkAndTrackProviderSession:未达到上限时应返回 allowed 且可标记 tracked", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]);
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(result).toEqual({ allowed: true, count: 1, tracked: true });
    });

    it("checkAndTrackProviderSession: should pass SESSION_TTL_MS as ARGV[4] to Lua script", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]);
      await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(redisClientRef.eval).toHaveBeenCalledTimes(1);
      const evalCall = redisClientRef.eval.mock.calls[0];
      // eval(script, numkeys, key, sessionId, limit, now, ttlMs)
      //  index:  0      1      2       3       4     5     6
      // => script + numkeys + 1 KEY + 4 ARGV = 7 arguments; ARGV[4] (ttlMs) sits at index 6.
      expect(evalCall.length).toBe(7);
      // SESSION_TTL_MS is derived from env (default 300s = 300000ms).
      const ttlMsArg = evalCall[6];
      expect(ttlMsArg).toBe("300000");
    });

    it("trackUserDailyCost:fixed 模式应使用 STRING + TTL", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackUserDailyCost(1, 1.25, "00:00", "fixed");
      expect(pipelineCalls.some((c) => c[0] === "incrbyfloat")).toBe(true);
      expect(pipelineCalls.some((c) => c[0] === "expire")).toBe(true);
    });

    it("trackUserDailyCost:rolling 模式应使用 ZSET Lua 脚本", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackUserDailyCost(1, 1.25, "00:00", "rolling", { requestId: 123 });
      expect(redisClientRef.eval).toHaveBeenCalled();
    });

    it("checkUserRPM:达到上限时应拦截", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      // Read-pipeline results; the second entry is the zcard count (5 = already at the limit).
      pipeline.exec.mockResolvedValueOnce([
        [null, 0],
        [null, 5],
      ]);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const result = await RateLimitService.checkUserRPM(1, 5);
      expect(result.allowed).toBe(false);
      expect(result.current).toBe(5);
      // A rejected request must not be recorded in the sliding window.
      expect(pipelineCalls.some((c) => c[0] === "zadd")).toBe(false);
    });

    it("checkUserRPM:未达到上限时应写入本次请求并放行", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const readPipeline = makePipeline();
      // The second entry is the zcard count (3, below the limit of 5).
      readPipeline.exec.mockResolvedValueOnce([
        [null, 0],
        [null, 3],
      ]);
      const writePipeline = makePipeline();
      writePipeline.exec.mockResolvedValueOnce([]);
      redisClientRef.pipeline.mockReturnValueOnce(readPipeline).mockReturnValueOnce(writePipeline);
      const result = await RateLimitService.checkUserRPM(1, 5);
      expect(result.allowed).toBe(true);
      expect(result.current).toBe(4);
      expect(writePipeline.zadd).toHaveBeenCalledTimes(1);
    });

    it("checkRpmLimit:user 类型应复用 checkUserRPM 逻辑", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const readPipeline = makePipeline();
      readPipeline.exec.mockResolvedValueOnce([
        [null, 0],
        [null, 1],
      ]);
      const writePipeline = makePipeline();
      writePipeline.exec.mockResolvedValueOnce([]);
      redisClientRef.pipeline.mockReturnValueOnce(readPipeline).mockReturnValueOnce(writePipeline);
      const result = await RateLimitService.checkRpmLimit(1, "user", 2);
      expect(result.allowed).toBe(true);
      expect(result.current).toBe(2);
    });

    it("getCurrentCostBatch:providerIds 为空时应返回空 Map", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const result = await RateLimitService.getCurrentCostBatch([], new Map());
      expect(result.size).toBe(0);
    });

    it("getCurrentCostBatch:Redis 非 ready 时应返回默认 0", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.status = "end";
      const result = await RateLimitService.getCurrentCostBatch([1], new Map());
      expect(result.get(1)).toEqual({ cost5h: 0, costDaily: 0, costWeekly: 0, costMonthly: 0 });
    });

    it("getCurrentCostBatch:应按 pipeline 返回解析 5h/daily/weekly/monthly", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      // queryMeta order: 5h (eval), daily (get, fixed mode), weekly (get), monthly (get).
      pipeline.exec.mockResolvedValueOnce([
        [null, "1.5"],
        [null, "2.5"],
        [null, "3.5"],
        [null, "4.5"],
      ]);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const dailyResetConfigs = new Map<
        number,
        { resetTime?: string | null; resetMode?: string | null }
      >();
      dailyResetConfigs.set(1, { resetTime: "00:00", resetMode: "fixed" });
      const result = await RateLimitService.getCurrentCostBatch([1], dailyResetConfigs);
      expect(result.get(1)).toEqual({
        cost5h: 1.5,
        costDaily: 2.5,
        costWeekly: 3.5,
        costMonthly: 4.5,
      });
    });

    it("checkCostLimits:5h 滚动窗口超限时应返回 not allowed", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce("11");
      const result = await RateLimitService.checkCostLimits(1, "provider", {
        limit_5h_usd: 10,
        limit_daily_usd: null,
        limit_weekly_usd: null,
        limit_monthly_usd: null,
      });
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("供应商 5小时消费上限已达到(11.0000/10)");
    });

    it("checkCostLimits:daily rolling cache miss 时应回退 DB 并 warm ZSET", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      // Cache miss: the rolling sum reads "0" and the ZSET key does not exist.
      redisClientRef.eval.mockResolvedValueOnce("0");
      redisClientRef.exists.mockResolvedValueOnce(0);
      statisticsMock.findProviderCostEntriesInTimeRange.mockResolvedValueOnce([
        { id: 101, createdAt: new Date(nowMs - 60_000), costUsd: 3 },
        { id: 102, createdAt: new Date(nowMs - 30_000), costUsd: 9 },
      ]);
      const result = await RateLimitService.checkCostLimits(9, "provider", {
        limit_5h_usd: null,
        limit_daily_usd: 10,
        daily_reset_mode: "rolling",
        daily_reset_time: "00:00",
        limit_weekly_usd: null,
        limit_monthly_usd: null,
      });
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("供应商 每日消费上限已达到(12.0000/10)");
      expect(pipelineCalls.some((c) => c[0] === "zadd")).toBe(true);
    });

    it("getCurrentCost:daily fixed cache hit 时应直接返回当前值", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.get.mockImplementation(async (key: string) => {
        if (key === "provider:9:cost_daily_0000") return "7.5";
        return null;
      });
      const current = await RateLimitService.getCurrentCost(9, "provider", "daily", "00:00", "fixed");
      expect(current).toBeCloseTo(7.5, 10);
    });

    it("getCurrentCost:daily rolling cache miss 时应从 DB 重建并返回", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce("0");
      redisClientRef.exists.mockResolvedValueOnce(0);
      statisticsMock.findProviderCostEntriesInTimeRange.mockResolvedValueOnce([
        { id: 101, createdAt: new Date(nowMs - 60_000), costUsd: 2 },
        { id: 102, createdAt: new Date(nowMs - 30_000), costUsd: 3 },
      ]);
      const current = await RateLimitService.getCurrentCost(
        9,
        "provider",
        "daily",
        "00:00",
        "rolling"
      );
      expect(current).toBeCloseTo(5, 10);
      expect(pipelineCalls.some((c) => c[0] === "zadd")).toBe(true);
    });

    it("trackCost:fixed 模式应写入 key/provider 的 daily+weekly+monthly(STRING)", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackCost(1, 9, "sess", 1.25, {
        keyResetMode: "fixed",
        providerResetMode: "fixed",
        keyResetTime: "00:00",
        providerResetTime: "00:00",
        requestId: 123,
        createdAtMs: nowMs,
      });
      // The 5h rolling-window Lua script runs at least twice (once for the key, once for the provider).
      expect(redisClientRef.eval.mock.calls.length).toBeGreaterThanOrEqual(2);
      expect(pipelineCalls.filter((c) => c[0] === "incrbyfloat").length).toBeGreaterThanOrEqual(4);
      expect(pipelineCalls.filter((c) => c[0] === "expire").length).toBeGreaterThanOrEqual(4);
    });

    it("trackCost:rolling 模式应写入 key/provider 的 daily_rolling(ZSET)", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackCost(1, 9, "sess", 1.25, {
        keyResetMode: "rolling",
        providerResetMode: "rolling",
        requestId: 123,
        createdAtMs: nowMs,
      });
      // eval(script, numkeys, key, ...): index 2 is the Redis key the script writes to.
      const evalArgs = redisClientRef.eval.mock.calls.map((c: unknown[]) => String(c[2]));
      expect(evalArgs.some((k) => k === "key:1:cost_daily_rolling")).toBe(true);
      expect(evalArgs.some((k) => k === "provider:9:cost_daily_rolling")).toBe(true);
    });

    it("getCurrentCostBatch:pipeline.exec 返回 null 时应返回默认值", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      pipeline.exec.mockResolvedValueOnce(null);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const result = await RateLimitService.getCurrentCostBatch([1], new Map());
      expect(result.get(1)).toEqual({ cost5h: 0, costDaily: 0, costWeekly: 0, costMonthly: 0 });
    });

    it("getCurrentCostBatch:单个 query 出错时应跳过该项", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      pipeline.exec.mockResolvedValueOnce([
        [new Error("boom"), null],
        [null, "2.5"],
        [null, "3.5"],
        [null, "4.5"],
      ]);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const result = await RateLimitService.getCurrentCostBatch([1], new Map());
      // The 5h query failed, so it keeps the default 0; the other windows parse normally.
      expect(result.get(1)).toEqual({ cost5h: 0, costDaily: 2.5, costWeekly: 3.5, costMonthly: 4.5 });
    });
  });