service-extra.test.ts 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
  /** Fake Redis client for the current test; reassigned in `beforeEach` and handed out by the mocked `getRedisClient`. */
  let redisClientRef: any;
  /** Chronological log of every fake-pipeline command, stored as `[commandName, ...args]`. */
  const pipelineCalls: unknown[][] = [];
  /**
   * Builds a fresh fake Redis pipeline for a single test interaction.
   *
   * Each chainable command (`eval`, `get`, `incrbyfloat`, `expire`,
   * `zremrangebyscore`, `zcard`, `zadd`) is its own spy. When invoked it appends
   * `[commandName, ...args]` to `pipelineCalls` and returns the pipeline itself,
   * so calls can be chained. `exec` logs `["exec"]` and resolves to `[]` unless a
   * test overrides it (e.g. with `mockResolvedValueOnce`).
   */
  const makePipeline = () => {
    // Factory for a recording, chainable spy; `pipeline` is only dereferenced when the spy runs.
    const chainable = (commandName: string) =>
      vi.fn((...args: unknown[]) => {
        pipelineCalls.push([commandName, ...args]);
        return pipeline;
      });

    const pipeline = {
      eval: chainable("eval"),
      get: chainable("get"),
      incrbyfloat: chainable("incrbyfloat"),
      expire: chainable("expire"),
      zremrangebyscore: chainable("zremrangebyscore"),
      zcard: chainable("zcard"),
      zadd: chainable("zadd"),
      exec: vi.fn(async () => {
        pipelineCalls.push(["exec"]);
        return [];
      }),
    };

    return pipeline;
  };
  // Replace the logger with inert spies so rate-limit code can log without side effects.
  vi.mock("@/lib/logger", () => {
    const logger = {
      debug: vi.fn(),
      info: vi.fn(),
      warn: vi.fn(),
      error: vi.fn(),
    };
    return { logger };
  });
  // Every Redis lookup returns the per-test fake. `redisClientRef` is read at call
  // time, so the reassignment in `beforeEach` is picked up by each test.
  vi.mock("@/lib/redis", () => ({
    getRedisClient() {
      return redisClientRef;
    },
  }));
  // Created with vi.hoisted so the (hoisted) vi.mock factory below can safely reference it.
  const resolveSystemTimezoneMock = vi.hoisted(() => {
    const defaultTimezone = "Asia/Shanghai";
    return vi.fn(async () => defaultTimezone);
  });
  vi.mock("@/lib/utils/timezone", () => {
    return { resolveSystemTimezone: resolveSystemTimezoneMock };
  });
  /**
   * Spies standing in for `@/repository/statistics`.
   *
   * Built with `vi.hoisted` (like `resolveSystemTimezoneMock`) because `vi.mock`
   * is hoisted to the top of the file. With a plain `const`, the factory below
   * would throw a ReferenceError (temporal dead zone) as soon as the mocked
   * module were imported statically instead of via `await import(...)`.
   */
  const statisticsMock = vi.hoisted(() => ({
    // service.ts imports these statically at module top level, so the exports must exist.
    sumKeyTotalCost: vi.fn(async () => 0),
    sumUserTotalCost: vi.fn(async () => 0),
    sumUserCostInTimeRange: vi.fn(async () => 0),
    // getCurrentCost / checkCostLimitsFromDatabase destructure these from a dynamic import.
    findKeyCostEntriesInTimeRange: vi.fn(async () => []),
    findProviderCostEntriesInTimeRange: vi.fn(async () => []),
    findUserCostEntriesInTimeRange: vi.fn(async () => []),
    sumKeyCostInTimeRange: vi.fn(async () => 0),
    sumProviderCostInTimeRange: vi.fn(async () => 0),
  }));
  vi.mock("@/repository/statistics", () => statisticsMock);
  /**
   * Spies standing in for `SessionTracker` from `@/lib/session-tracker`.
   *
   * Built with `vi.hoisted` so the hoisted `vi.mock` factory below never reads
   * this binding while it is still in its temporal dead zone (which would happen
   * with a plain `const` if the module under test were imported statically).
   */
  const sessionTrackerMock = vi.hoisted(() => ({
    getKeySessionCount: vi.fn(async () => 0),
    getProviderSessionCount: vi.fn(async () => 0),
    getUserSessionCount: vi.fn(async () => 0),
  }));
  vi.mock("@/lib/session-tracker", () => ({
    SessionTracker: sessionTrackerMock,
  }));
  describe("RateLimitService - other quota paths", () => {
    // Fixed "now" (2023-11-14T22:13:20Z) shared by the fake clock and DB fixtures.
    const nowMs = 1_700_000_000_000;

    beforeEach(() => {
      vi.resetAllMocks();
      // resetAllMocks also resets spy implementations (on Vitest < 3 they become
      // no-ops returning undefined), so re-install every default the module under
      // test may fall back on. Per-test mock*Once overrides still take precedence.
      resolveSystemTimezoneMock.mockResolvedValue("Asia/Shanghai");
      for (const zeroSpy of [
        statisticsMock.sumKeyTotalCost,
        statisticsMock.sumUserTotalCost,
        statisticsMock.sumUserCostInTimeRange,
        statisticsMock.sumKeyCostInTimeRange,
        statisticsMock.sumProviderCostInTimeRange,
        sessionTrackerMock.getKeySessionCount,
        sessionTrackerMock.getProviderSessionCount,
        sessionTrackerMock.getUserSessionCount,
      ]) {
        zeroSpy.mockResolvedValue(0);
      }
      for (const emptyListSpy of [
        statisticsMock.findKeyCostEntriesInTimeRange,
        statisticsMock.findProviderCostEntriesInTimeRange,
        statisticsMock.findUserCostEntriesInTimeRange,
      ]) {
        emptyListSpy.mockResolvedValue([]);
      }
      pipelineCalls.length = 0;
      vi.useFakeTimers();
      vi.setSystemTime(new Date(nowMs));
      // Fresh fake client per test; `status` and individual spies are tweaked per case.
      redisClientRef = {
        status: "ready",
        eval: vi.fn(async () => "0"),
        exists: vi.fn(async () => 1),
        get: vi.fn(async () => null),
        set: vi.fn(async () => "OK"),
        setex: vi.fn(async () => "OK"),
        pipeline: vi.fn(() => makePipeline()),
      };
    });

    afterEach(() => {
      vi.useRealTimers();
    });

    it("checkSessionLimit:limit<=0 时应放行", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await expect(RateLimitService.checkSessionLimit(1, "key", 0)).resolves.toEqual({
        allowed: true,
      });
    });

    it("checkSessionLimit:Key 并发数达到上限时应拦截", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      sessionTrackerMock.getKeySessionCount.mockResolvedValueOnce(2);
      const result = await RateLimitService.checkSessionLimit(1, "key", 2);
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("Key并发 Session 上限已达到(2/2)");
    });

    it("checkSessionLimit:Provider 并发数未达上限时应放行", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      sessionTrackerMock.getProviderSessionCount.mockResolvedValueOnce(1);
      await expect(RateLimitService.checkSessionLimit(9, "provider", 2)).resolves.toEqual({
        allowed: true,
      });
    });

    it("checkAndTrackProviderSession:limit<=0 时应放行且不追踪", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 0);
      expect(result).toEqual({ allowed: true, count: 0, tracked: false });
    });

    it("checkAndTrackProviderSession:Redis 非 ready 时应 Fail Open", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.status = "end";
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(result).toEqual({ allowed: true, count: 0, tracked: false });
    });

    it("checkAndTrackProviderSession:达到上限时应返回 not allowed", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([0, 2, 0]);
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("供应商并发 Session 上限已达到(2/2)");
    });

    it("checkAndTrackProviderSession:未达到上限时应返回 allowed 且可标记 tracked", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]);
      const result = await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      expect(result).toEqual({ allowed: true, count: 1, tracked: true });
    });

    it("checkAndTrackProviderSession: should pass SESSION_TTL_MS as ARGV[4] to Lua script", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([1, 1, 1]);
      await RateLimitService.checkAndTrackProviderSession(9, "sess", 2);
      // Verify eval was called with the correct args including ARGV[4] = SESSION_TTL_MS
      expect(redisClientRef.eval).toHaveBeenCalledTimes(1);
      const evalCall = redisClientRef.eval.mock.calls[0];
      // evalCall: [script, numkeys, key, sessionId, limit, now, ttlMs]
      // Indices:   0       1        2    3          4      5    6
      expect(evalCall.length).toBe(7); // script + numkeys + 1 key + 4 ARGV
      // ARGV[4] (index 6) is SESSION_TTL_MS. This expects the default of 300s = 300000ms,
      // so the test will fail if the environment overrides the session TTL.
      const ttlMsArg = evalCall[6];
      expect(ttlMsArg).toBe("300000");
    });

    it("checkAndTrackKeyUserSession:keyLimit/userLimit 均 <=0 时应放行且不追踪", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const result = await RateLimitService.checkAndTrackKeyUserSession(2, 1, "sess", 0, 0);
      expect(result).toEqual({
        allowed: true,
        keyCount: 0,
        userCount: 0,
        trackedKey: false,
        trackedUser: false,
      });
      expect(redisClientRef.eval).not.toHaveBeenCalled();
    });

    it("checkAndTrackKeyUserSession:Redis 非 ready 时应 Fail Open", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.status = "end";
      const result = await RateLimitService.checkAndTrackKeyUserSession(2, 1, "sess", 2, 2);
      expect(result).toEqual({
        allowed: true,
        keyCount: 0,
        userCount: 0,
        trackedKey: false,
        trackedUser: false,
      });
    });

    it("checkAndTrackKeyUserSession:Key 超限时应返回 not allowed", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([0, 1, 2, 0, 1, 0]);
      const result = await RateLimitService.checkAndTrackKeyUserSession(2, 1, "sess", 2, 10);
      expect(result.allowed).toBe(false);
      expect(result.rejectedBy).toBe("key");
      expect(result.reasonCode).toBe("RATE_LIMIT_CONCURRENT_SESSIONS_EXCEEDED");
      expect(result.reasonParams).toEqual({ current: 2, limit: 2, target: "key" });
    });

    it("checkAndTrackKeyUserSession:User 超限时应返回 not allowed", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([0, 2, 1, 0, 2, 0]);
      const result = await RateLimitService.checkAndTrackKeyUserSession(2, 1, "sess", 10, 2);
      expect(result.allowed).toBe(false);
      expect(result.rejectedBy).toBe("user");
      expect(result.reasonCode).toBe("RATE_LIMIT_CONCURRENT_SESSIONS_EXCEEDED");
      expect(result.reasonParams).toEqual({ current: 2, limit: 2, target: "user" });
    });

    it("checkAndTrackKeyUserSession:未超限时应返回 allowed 且可标记 tracked", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([1, 0, 2, 1, 2, 1]);
      const result = await RateLimitService.checkAndTrackKeyUserSession(2, 1, "sess", 2, 2);
      expect(result).toEqual({
        allowed: true,
        keyCount: 2,
        userCount: 2,
        trackedKey: true,
        trackedUser: true,
      });
    });

    it("checkAndTrackKeyUserSession: should pass SESSION_TTL_MS as ARGV[5] to Lua script", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce([1, 0, 1, 1, 1, 1]);
      await RateLimitService.checkAndTrackKeyUserSession(2, 1, "sess", 2, 2);
      expect(redisClientRef.eval).toHaveBeenCalledTimes(1);
      const evalCall = redisClientRef.eval.mock.calls[0];
      // evalCall: [script, numkeys, globalKey, keyKey, userKey, sessionId, keyLimit, userLimit, now, ttlMs]
      // Indices:   0       1        2          3       4        5          6         7          8    9
      expect(evalCall.length).toBe(10);
      // Default session TTL (300s = 300000ms); fails if the environment overrides it.
      const ttlMsArg = evalCall[9];
      expect(ttlMsArg).toBe("300000");
    });

    it("trackUserDailyCost:fixed 模式应使用 STRING + TTL", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackUserDailyCost(1, 1.25, "00:00", "fixed");
      expect(pipelineCalls.some((c) => c[0] === "incrbyfloat")).toBe(true);
      expect(pipelineCalls.some((c) => c[0] === "expire")).toBe(true);
    });

    it("trackUserDailyCost:rolling 模式应使用 ZSET Lua 脚本", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackUserDailyCost(1, 1.25, "00:00", "rolling", { requestId: 123 });
      expect(redisClientRef.eval).toHaveBeenCalled();
    });

    it("checkUserRPM:达到上限时应拦截", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      pipeline.exec
        .mockResolvedValueOnce([
          [null, 0],
          [null, 5], // zcard returns 5
        ])
        .mockResolvedValueOnce([]); // write pipeline
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const result = await RateLimitService.checkUserRPM(1, 5);
      expect(result.allowed).toBe(false);
      expect(result.current).toBe(5);
    });

    it("checkUserRPM:未达到上限时应写入本次请求并放行", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const readPipeline = makePipeline();
      readPipeline.exec.mockResolvedValueOnce([
        [null, 0],
        [null, 3], // zcard returns 3
      ]);
      const writePipeline = makePipeline();
      writePipeline.exec.mockResolvedValueOnce([]);
      redisClientRef.pipeline.mockReturnValueOnce(readPipeline).mockReturnValueOnce(writePipeline);
      const result = await RateLimitService.checkUserRPM(1, 5);
      expect(result.allowed).toBe(true);
      expect(result.current).toBe(4);
      expect(writePipeline.zadd).toHaveBeenCalledTimes(1);
    });

    it("checkRpmLimit:user 类型应复用 checkUserRPM 逻辑", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const readPipeline = makePipeline();
      readPipeline.exec.mockResolvedValueOnce([
        [null, 0],
        [null, 1],
      ]);
      const writePipeline = makePipeline();
      writePipeline.exec.mockResolvedValueOnce([]);
      redisClientRef.pipeline.mockReturnValueOnce(readPipeline).mockReturnValueOnce(writePipeline);
      const result = await RateLimitService.checkRpmLimit(1, "user", 2);
      expect(result.allowed).toBe(true);
      expect(result.current).toBe(2);
    });

    it("getCurrentCostBatch:providerIds 为空时应返回空 Map", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const result = await RateLimitService.getCurrentCostBatch([], new Map());
      expect(result.size).toBe(0);
    });

    it("getCurrentCostBatch:Redis 非 ready 时应返回默认 0", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.status = "end";
      const result = await RateLimitService.getCurrentCostBatch([1], new Map());
      expect(result.get(1)).toEqual({ cost5h: 0, costDaily: 0, costWeekly: 0, costMonthly: 0 });
    });

    it("getCurrentCostBatch:应按 pipeline 返回解析 5h/daily/weekly/monthly", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      // Query order: 5h (eval), daily (get, fixed mode), weekly (get), monthly (get)
      pipeline.exec.mockResolvedValueOnce([
        [null, "1.5"],
        [null, "2.5"],
        [null, "3.5"],
        [null, "4.5"],
      ]);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const dailyResetConfigs = new Map<
        number,
        { resetTime?: string | null; resetMode?: string | null }
      >();
      dailyResetConfigs.set(1, { resetTime: "00:00", resetMode: "fixed" });
      const result = await RateLimitService.getCurrentCostBatch([1], dailyResetConfigs);
      expect(result.get(1)).toEqual({
        cost5h: 1.5,
        costDaily: 2.5,
        costWeekly: 3.5,
        costMonthly: 4.5,
      });
    });

    it("checkCostLimits:5h 滚动窗口超限时应返回 not allowed", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce("11");
      const result = await RateLimitService.checkCostLimits(1, "provider", {
        limit_5h_usd: 10,
        limit_daily_usd: null,
        limit_weekly_usd: null,
        limit_monthly_usd: null,
      });
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("供应商 5小时消费上限已达到(11.0000/10)");
    });

    it("checkCostLimits:daily rolling cache miss 时应回退 DB 并 warm ZSET", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      // Rolling ZSET read returns 0 and the key does not exist -> treated as a cache miss.
      redisClientRef.eval.mockResolvedValueOnce("0");
      redisClientRef.exists.mockResolvedValueOnce(0);
      statisticsMock.findProviderCostEntriesInTimeRange.mockResolvedValueOnce([
        { id: 101, createdAt: new Date(nowMs - 60_000), costUsd: 3 },
        { id: 102, createdAt: new Date(nowMs - 30_000), costUsd: 9 },
      ]);
      const result = await RateLimitService.checkCostLimits(9, "provider", {
        limit_5h_usd: null,
        limit_daily_usd: 10,
        daily_reset_mode: "rolling",
        daily_reset_time: "00:00",
        limit_weekly_usd: null,
        limit_monthly_usd: null,
      });
      expect(result.allowed).toBe(false);
      expect(result.reason).toContain("供应商 每日消费上限已达到(12.0000/10)");
      // The DB entries must be written back into the ZSET via the pipeline.
      expect(pipelineCalls.some((c) => c[0] === "zadd")).toBe(true);
    });

    it("getCurrentCost:daily fixed cache hit 时应直接返回当前值", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.get.mockImplementation(async (key: string) => {
        if (key === "provider:9:cost_daily_0000") return "7.5";
        return null;
      });
      const current = await RateLimitService.getCurrentCost(9, "provider", "daily", "00:00", "fixed");
      expect(current).toBeCloseTo(7.5, 10);
    });

    it("getCurrentCost:daily rolling cache miss 时应从 DB 重建并返回", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      redisClientRef.eval.mockResolvedValueOnce("0");
      redisClientRef.exists.mockResolvedValueOnce(0);
      statisticsMock.findProviderCostEntriesInTimeRange.mockResolvedValueOnce([
        { id: 101, createdAt: new Date(nowMs - 60_000), costUsd: 2 },
        { id: 102, createdAt: new Date(nowMs - 30_000), costUsd: 3 },
      ]);
      const current = await RateLimitService.getCurrentCost(
        9,
        "provider",
        "daily",
        "00:00",
        "rolling"
      );
      expect(current).toBeCloseTo(5, 10);
      expect(pipelineCalls.some((c) => c[0] === "zadd")).toBe(true);
    });

    it("trackCost:fixed 模式应写入 key/provider 的 daily+weekly+monthly(STRING)", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackCost(1, 9, "sess", 1.25, {
        keyResetMode: "fixed",
        providerResetMode: "fixed",
        keyResetTime: "00:00",
        providerResetTime: "00:00",
        requestId: 123,
        createdAtMs: nowMs,
      });
      // The 5h Lua script runs at least twice (once for the key, once for the provider).
      expect(redisClientRef.eval.mock.calls.length).toBeGreaterThanOrEqual(2);
      expect(pipelineCalls.filter((c) => c[0] === "incrbyfloat").length).toBeGreaterThanOrEqual(4);
      expect(pipelineCalls.filter((c) => c[0] === "expire").length).toBeGreaterThanOrEqual(4);
    });

    it("trackCost:rolling 模式应写入 key/provider 的 daily_rolling(ZSET)", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      await RateLimitService.trackCost(1, 9, "sess", 1.25, {
        keyResetMode: "rolling",
        providerResetMode: "rolling",
        requestId: 123,
        createdAtMs: nowMs,
      });
      // Index 2 of each eval call is KEYS[1] (index 0 = script, index 1 = numkeys).
      const evalArgs = redisClientRef.eval.mock.calls.map((c: unknown[]) => String(c[2]));
      expect(evalArgs.some((k) => k === "key:1:cost_daily_rolling")).toBe(true);
      expect(evalArgs.some((k) => k === "provider:9:cost_daily_rolling")).toBe(true);
    });

    it("getCurrentCostBatch:pipeline.exec 返回 null 时应返回默认值", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      pipeline.exec.mockResolvedValueOnce(null);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const result = await RateLimitService.getCurrentCostBatch([1], new Map());
      expect(result.get(1)).toEqual({ cost5h: 0, costDaily: 0, costWeekly: 0, costMonthly: 0 });
    });

    it("getCurrentCostBatch:单个 query 出错时应跳过该项", async () => {
      const { RateLimitService } = await import("@/lib/rate-limit");
      const pipeline = makePipeline();
      pipeline.exec.mockResolvedValueOnce([
        [new Error("boom"), null],
        [null, "2.5"],
        [null, "3.5"],
        [null, "4.5"],
      ]);
      redisClientRef.pipeline.mockReturnValueOnce(pipeline);
      const result = await RateLimitService.getCurrentCostBatch([1], new Map());
      // The 5h query failed, so it keeps the default 0; the remaining windows parse normally.
      expect(result.get(1)).toEqual({ cost5h: 0, costDaily: 2.5, costWeekly: 3.5, costMonthly: 4.5 });
    });
  });