2
0

my-usage-token-aggregation.test.ts 5.8 KB


  1. import { describe, expect, test, vi } from "vitest";
  // Disable the DSN/Redis-based default sync and cleanup coordination in tests/setup.ts,
  // so that unrelated dependencies are not pulled in.
  Object.assign(process.env, { DSN: "", AUTO_CLEANUP_TEST_DATA: "false" });
  /**
   * Flattens a SQL-like object tree into a plain string for substring assertions.
   *
   * Traversal rules:
   * - falsy nodes and non-string primitives render as "";
   * - strings render as themselves;
   * - arrays render as the concatenation of their rendered elements;
   * - an object with a truthy `value` renders that value (array items are
   *   stringified and joined, anything else goes through `String`);
   * - otherwise an object with truthy `queryChunks` renders those chunks.
   *
   * @param sqlObj Root node to render.
   * @returns The concatenated text found in the tree.
   */
  function sqlToString(sqlObj: unknown): string {
    // Cycle guard. Only objects are tracked: tracking primitive strings would
    // make every repeated string chunk (e.g. the same column name used twice)
    // disappear from the output after its first occurrence.
    const seenObjects = new Set<unknown>();

    const walk = (node: unknown): string => {
      if (!node) return "";
      if (typeof node === "string") return node;
      if (typeof node !== "object") return "";

      // Each object is rendered at most once, which also breaks reference cycles.
      if (seenObjects.has(node)) return "";
      seenObjects.add(node);

      if (Array.isArray(node)) {
        return node.map((child) => walk(child)).join("");
      }

      const { value, queryChunks } = node as { value?: unknown; queryChunks?: unknown };
      if (value) {
        return Array.isArray(value) ? value.map(String).join("") : String(value);
      }
      if (queryChunks) {
        return walk(queryChunks);
      }
      return "";
    };

    return walk(sqlObj);
  }
  /**
   * Builds a fake query-builder object that is also awaitable.
   *
   * The returned object is a resolved promise for `result`, decorated with
   * chainable builder methods; each method is a `vi.fn` that returns the same
   * object, so any call chain can be awaited to obtain `result`.
   *
   * @param result Value the query resolves to when awaited.
   * @returns The decorated promise.
   */
  function createThenableQuery<T>(result: T) {
    const chainableMethods = [
      "from",
      "innerJoin",
      "leftJoin",
      "where",
      "groupBy",
      "orderBy",
      "limit",
      "offset",
    ] as const;

    const builder: any = Promise.resolve(result);
    for (const method of chainableMethods) {
      builder[method] = vi.fn(() => builder);
    }
    return builder;
  }
  // Shared mock handles. They are created through vi.hoisted so that the
  // vi.mock factories in this file can reference them.
  const mocks = vi.hoisted(() => {
    return {
      getSession: vi.fn(),
      getSystemSettings: vi.fn(),
      getEnvConfig: vi.fn(),
      getTimeRangeForPeriodWithMode: vi.fn(),
      findUsageLogsStats: vi.fn(),
      select: vi.fn(),
      // Resolves to a zero-count result by default.
      execute: vi.fn(async () => ({ count: 0 })),
    };
  });
  // Replace the auth module: only getSession is provided.
  vi.mock("@/lib/auth", () => {
    return { getSession: mocks.getSession };
  });

  // Replace the system-config repository: only getSystemSettings is provided.
  vi.mock("@/repository/system-config", () => {
    return { getSystemSettings: mocks.getSystemSettings };
  });

  // Replace the config module: only getEnvConfig is provided.
  vi.mock("@/lib/config", () => {
    return { getEnvConfig: mocks.getEnvConfig };
  });

  // Replace the rate-limit time utilities: only getTimeRangeForPeriodWithMode is provided.
  vi.mock("@/lib/rate-limit/time-utils", () => {
    return { getTimeRangeForPeriodWithMode: mocks.getTimeRangeForPeriodWithMode };
  });

  // Partially mock the usage-logs repository: keep every real export and
  // override findUsageLogsStats only.
  vi.mock("@/repository/usage-logs", async (importOriginal) => {
    const realModule = await importOriginal<typeof import("@/repository/usage-logs")>();
    return { ...realModule, findUsageLogsStats: mocks.findUsageLogsStats };
  });

  // Replace the database handle with an object exposing the mocked select/execute.
  vi.mock("@/drizzle/db", () => {
    const db = { select: mocks.select, execute: mocks.execute };
    return { db };
  });
  /**
   * Asserts that the SQL expression stored under `field` in a captured select
   * shape is a SUM that is cast to double precision and not to an int type.
   *
   * @param selection Object passed to the mocked `db.select`.
   * @param field Key of the token column to inspect.
   */
  function expectNoIntTokenSum(selection: Record<string, unknown>, field: string) {
    // Compare case-insensitively by lowercasing the rendered SQL text.
    const renderedSql = sqlToString(selection[field]).toLowerCase();

    expect(renderedSql).toContain("sum");
    for (const forbiddenCast of ["::int", "::int4"]) {
      expect(renderedSql).not.toContain(forbiddenCast);
    }
    expect(renderedSql).toContain("double precision");
  }
  describe("my-usage token aggregation", () => {
    /**
     * Installs a `db.select` implementation that records every selection object
     * it receives and answers with the next queued query, falling back to an
     * empty thenable query once the queue is drained.
     *
     * @param queuedQueries Queries to hand out, in call order.
     * @returns The live array of recorded selection objects.
     */
    function recordSelections(queuedQueries: unknown[]): Array<Record<string, unknown>> {
      const recorded: Array<Record<string, unknown>> = [];
      mocks.select.mockImplementation((selection: unknown) => {
        recorded.push(selection as Record<string, unknown>);
        return queuedQueries.shift() ?? createThenableQuery([]);
      });
      return recorded;
    }

    test("getMyTodayStats: token sum 不应使用 ::int", async () => {
      // Start from a fresh module graph before importing the action under test.
      vi.resetModules();

      const zeroTotalsRow = {
        calls: 0,
        inputTokens: 0,
        outputTokens: 0,
        costUsd: "0",
      };
      const selections = recordSelections([
        createThenableQuery([zeroTotalsRow]),
        createThenableQuery([]),
      ]);

      mocks.getTimeRangeForPeriodWithMode.mockResolvedValue({
        startTime: new Date("2024-01-01T00:00:00.000Z"),
        endTime: new Date("2024-01-02T00:00:00.000Z"),
      });
      mocks.getSession.mockResolvedValue({
        key: {
          id: 1,
          key: "k",
          dailyResetTime: "00:00",
          dailyResetMode: "fixed",
        },
        user: { id: 1 },
      });
      mocks.getSystemSettings.mockResolvedValue({
        currencyDisplay: "USD",
        billingModelSource: "original",
      });

      const { getMyTodayStats } = await import("@/actions/my-usage");
      const res = await getMyTodayStats();

      expect(res.ok).toBe(true);
      expect(selections.length).toBeGreaterThanOrEqual(2);
      // Only the first two recorded selections are inspected.
      for (const selection of selections.slice(0, 2)) {
        expectNoIntTokenSum(selection, "inputTokens");
        expectNoIntTokenSum(selection, "outputTokens");
      }
    });

    test("getMyStatsSummary: token sum 不应使用 ::int", async () => {
      // Start from a fresh module graph before importing the action under test.
      vi.resetModules();

      const selections = recordSelections([createThenableQuery([]), createThenableQuery([])]);

      mocks.getEnvConfig.mockReturnValue({ TZ: "UTC" });
      mocks.getSession.mockResolvedValue({
        key: { id: 1, key: "k" },
        user: { id: 1 },
      });
      mocks.getSystemSettings.mockResolvedValue({
        currencyDisplay: "USD",
        billingModelSource: "original",
      });
      mocks.findUsageLogsStats.mockResolvedValue({
        totalRequests: 0,
        totalCost: 0,
        totalTokens: 0,
        totalInputTokens: 0,
        totalOutputTokens: 0,
        totalCacheCreationTokens: 0,
        totalCacheReadTokens: 0,
        totalCacheCreation5mTokens: 0,
        totalCacheCreation1hTokens: 0,
      });

      const { getMyStatsSummary } = await import("@/actions/my-usage");
      const res = await getMyStatsSummary({ startDate: "2024-01-01", endDate: "2024-01-01" });

      expect(res.ok).toBe(true);
      expect(selections).toHaveLength(2);

      const tokenFields = ["inputTokens", "outputTokens", "cacheCreationTokens", "cacheReadTokens"];
      for (const selection of selections) {
        for (const field of tokenFields) {
          expectNoIntTokenSum(selection, field);
        }
      }
    });
  });