// File: my-usage-concurrent-inherit.test.ts
  1. import { beforeEach, describe, expect, it, vi } from "vitest";
  2. const getSessionMock = vi.fn();
  3. vi.mock("@/lib/auth", () => ({
  4. getSession: getSessionMock,
  5. }));
  6. const getKeySessionCountMock = vi.fn(async () => 2);
  7. vi.mock("@/lib/session-tracker", () => ({
  8. SessionTracker: {
  9. getKeySessionCount: getKeySessionCountMock,
  10. },
  11. }));
  12. const getTimeRangeForPeriodWithModeMock = vi.fn(async () => ({
  13. startTime: new Date("2026-02-11T00:00:00.000Z"),
  14. endTime: new Date("2026-02-12T00:00:00.000Z"),
  15. }));
  16. const getTimeRangeForPeriodMock = vi.fn(async () => ({
  17. startTime: new Date("2026-02-11T00:00:00.000Z"),
  18. endTime: new Date("2026-02-12T00:00:00.000Z"),
  19. }));
  20. vi.mock("@/lib/rate-limit/time-utils", () => ({
  21. getTimeRangeForPeriodWithMode: getTimeRangeForPeriodWithModeMock,
  22. getTimeRangeForPeriod: getTimeRangeForPeriodMock,
  23. }));
  24. const statisticsMock = {
  25. sumUserCostInTimeRange: vi.fn(async () => 0),
  26. sumUserTotalCost: vi.fn(async () => 0),
  27. sumUserQuotaCosts: vi.fn(async () => ({
  28. cost5h: 0,
  29. costDaily: 0,
  30. costWeekly: 0,
  31. costMonthly: 0,
  32. costTotal: 0,
  33. })),
  34. sumKeyCostInTimeRange: vi.fn(async () => 0),
  35. sumKeyQuotaCostsById: vi.fn(async () => ({
  36. cost5h: 0,
  37. costDaily: 0,
  38. costWeekly: 0,
  39. costMonthly: 0,
  40. costTotal: 0,
  41. })),
  42. };
  43. vi.mock("@/repository/statistics", () => statisticsMock);
  44. const whereMock = vi.fn(async () => [{ id: 1 }]);
  45. const fromMock = vi.fn(() => ({ where: whereMock }));
  46. const selectMock = vi.fn(() => ({ from: fromMock }));
  47. vi.mock("@/drizzle/db", () => ({
  48. db: {
  49. select: selectMock,
  50. },
  51. }));
  52. vi.mock("@/lib/logger", () => ({
  53. logger: {
  54. warn: vi.fn(),
  55. error: vi.fn(),
  56. },
  57. }));
  58. function createSession(params: {
  59. keyLimitConcurrentSessions: number | null;
  60. userLimitConcurrentSessions: number | null;
  61. }) {
  62. return {
  63. key: {
  64. id: 1,
  65. key: "sk-test",
  66. name: "k",
  67. dailyResetTime: "00:00",
  68. dailyResetMode: "fixed",
  69. limit5hUsd: null,
  70. limitDailyUsd: null,
  71. limitWeeklyUsd: null,
  72. limitMonthlyUsd: null,
  73. limitTotalUsd: null,
  74. limitConcurrentSessions: params.keyLimitConcurrentSessions,
  75. providerGroup: null,
  76. isEnabled: true,
  77. expiresAt: null,
  78. },
  79. user: {
  80. id: 10,
  81. name: "u",
  82. dailyResetTime: "00:00",
  83. dailyResetMode: "fixed",
  84. limit5hUsd: null,
  85. dailyQuota: null,
  86. limitWeeklyUsd: null,
  87. limitMonthlyUsd: null,
  88. limitTotalUsd: null,
  89. limitConcurrentSessions: params.userLimitConcurrentSessions,
  90. rpm: null,
  91. providerGroup: null,
  92. isEnabled: true,
  93. expiresAt: null,
  94. allowedModels: [],
  95. allowedClients: [],
  96. },
  97. };
  98. }
  99. describe("getMyQuota - concurrent limit inheritance", () => {
  100. beforeEach(() => {
  101. vi.clearAllMocks();
  102. getSessionMock.mockResolvedValue(
  103. createSession({ keyLimitConcurrentSessions: 0, userLimitConcurrentSessions: 15 })
  104. );
  105. });
  106. it("Key 并发为 0 时应回退到 User 并发上限", async () => {
  107. const { getMyQuota } = await import("@/actions/my-usage");
  108. const result = await getMyQuota();
  109. expect(result.ok).toBe(true);
  110. if (result.ok) {
  111. expect(result.data.keyLimitConcurrentSessions).toBe(15);
  112. }
  113. });
  114. it("Key 并发为正数时应优先使用 Key 自身上限", async () => {
  115. getSessionMock.mockResolvedValue(
  116. createSession({ keyLimitConcurrentSessions: 5, userLimitConcurrentSessions: 15 })
  117. );
  118. const { getMyQuota } = await import("@/actions/my-usage");
  119. const result = await getMyQuota();
  120. expect(result.ok).toBe(true);
  121. if (result.ok) {
  122. expect(result.data.keyLimitConcurrentSessions).toBe(5);
  123. }
  124. });
  125. it("Key=0 且 User=0 时应返回 0(无限制)", async () => {
  126. getSessionMock.mockResolvedValue(
  127. createSession({ keyLimitConcurrentSessions: 0, userLimitConcurrentSessions: 0 })
  128. );
  129. const { getMyQuota } = await import("@/actions/my-usage");
  130. const result = await getMyQuota();
  131. expect(result.ok).toBe(true);
  132. if (result.ok) {
  133. expect(result.data.keyLimitConcurrentSessions).toBe(0);
  134. }
  135. });
  136. });