key-usage-token-overflow.test.ts 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import { describe, expect, test, vi } from "vitest";
  // Disable the DSN/Redis-based default sync and cleanup coordination in tests/setup.ts,
  // so this file does not pull in unrelated dependencies.
  Object.assign(process.env, {
    DSN: "",
    AUTO_CLEANUP_TEST_DATA: "false",
  });
  /**
   * Flattens a Drizzle-style SQL object into plain text so tests can assert on the
   * generated SQL fragment.
   *
   * Recognised node shapes: raw strings, arrays of nodes, objects carrying a truthy
   * `value` (a string[] is joined, anything else goes through `String`), and objects
   * carrying `queryChunks` (walked recursively). Every other node — null, numbers,
   * functions, objects without those keys — contributes "".
   *
   * Cycle protection only tracks the objects on the current recursion path. A string
   * or sub-fragment that legitimately appears more than once is therefore rendered every
   * time, instead of being silently dropped after its first occurrence.
   *
   * @param sqlObj - the node to render. Presumably the result of drizzle's `sql` template
   *   (not validated here).
   * @returns the concatenated text of every reachable string chunk.
   */
  function sqlToString(sqlObj: unknown): string {
    const ancestors = new Set<object>();
    const walk = (node: unknown): string => {
      // Primitives are never cycle-tracked: a Set would dedupe equal strings by value.
      if (typeof node === "string") return node;
      if (node === null || typeof node !== "object" || ancestors.has(node)) return "";
      ancestors.add(node);
      try {
        if (Array.isArray(node)) return node.map(walk).join("");
        const { value, queryChunks } = node as { value?: unknown; queryChunks?: unknown };
        if (value) {
          return Array.isArray(value) ? value.map(String).join("") : String(value);
        }
        if (queryChunks) return walk(queryChunks);
        return "";
      } finally {
        // Leave the path on exit so shared (non-cyclic) nodes elsewhere are still rendered.
        ancestors.delete(node);
      }
    };
    return walk(sqlObj);
  }
  /**
   * Builds a fake Drizzle query builder: a promise resolving to `result` that also
   * exposes the chainable builder methods stubbed below.
   *
   * Each method is its own `vi.fn` and returns the very same object. Any chain of these
   * calls can therefore be awaited and yields `result`.
   */
  function createThenableQuery<T>(result: T) {
    const chainableMethods = [
      "from",
      "leftJoin",
      "innerJoin",
      "where",
      "groupBy",
      "orderBy",
      "limit",
      "offset",
    ] as const;
    // `any` because builder methods are bolted onto a native Promise instance.
    const query: any = Promise.resolve(result);
    for (const methodName of chainableMethods) {
      query[methodName] = vi.fn(() => query);
    }
    return query;
  }
  describe("Key usage token aggregation overflow", () => {
    /**
     * Asserts that a captured select-field SQL fragment avoids a 32-bit integer cast
     * (`::int` / `::int4`) and casts to `double precision` instead. PostgreSQL `integer`
     * cannot hold large token totals.
     */
    const expectWideTokenCast = (fragment: unknown): void => {
      const rendered = sqlToString(fragment).toLowerCase();
      expect(rendered).not.toContain("::int");
      expect(rendered).not.toContain("::int4");
      expect(rendered).toContain("double precision");
    };

    test("findKeyUsageTodayBatch: token sum 不应使用 ::int", async () => {
      vi.resetModules();

      // Record the field map passed to every db.select() call.
      const capturedSelections: unknown[] = [];
      const selectMock = vi.fn((selection: unknown) => {
        capturedSelections.push(selection);
        return createThenableQuery([]);
      });

      vi.doMock("@/drizzle/db", () => ({
        db: {
          select: selectMock,
          // Give the afterAll cleanup logic in tests/setup.ts a usable `execute`.
          execute: vi.fn(async () => ({ count: 0 })),
        },
      }));

      // Import only after resetModules + doMock so the repository binds to the mocked db.
      const { findKeyUsageTodayBatch } = await import("@/repository/key");
      await findKeyUsageTodayBatch([1]);

      expect(capturedSelections).toHaveLength(1);
      const todaySelection = capturedSelections[0] as Record<string, unknown>;
      expectWideTokenCast(todaySelection.totalTokens);
    });

    test("findKeysWithStatisticsBatch: modelStats token sum 不应使用 ::int", async () => {
      vi.resetModules();

      const keyRow = {
        id: 10,
        userId: 1,
        key: "k",
        name: "n",
        isEnabled: true,
        expiresAt: null,
        canLoginWebUi: true,
        limit5hUsd: null,
        limitDailyUsd: null,
        dailyResetMode: "fixed",
        dailyResetTime: "00:00",
        limitWeeklyUsd: null,
        limitMonthlyUsd: null,
        limitTotalUsd: null,
        limitConcurrentSessions: 0,
        providerGroup: null,
        cacheTtlPreference: null,
        createdAt: new Date("2024-01-01T00:00:00.000Z"),
        updatedAt: new Date("2024-01-01T00:00:00.000Z"),
        deletedAt: null,
      };

      const capturedSelections: unknown[] = [];
      // Results handed to successive db.select() calls, in order: the key row, then two
      // empty result sets. Any later call receives the shared empty `fallbackSelect`.
      const pendingResults: unknown[] = [
        createThenableQuery([keyRow]),
        createThenableQuery([]),
        createThenableQuery([]),
      ];
      const fallbackSelect = createThenableQuery<unknown[]>([]);
      const selectMock = vi.fn((selection: unknown) => {
        capturedSelections.push(selection);
        return pendingResults.shift() ?? fallbackSelect;
      });
      const selectDistinctOnMock = vi.fn(() => createThenableQuery([]));

      vi.doMock("@/drizzle/db", () => ({
        db: {
          select: selectMock,
          selectDistinctOn: selectDistinctOnMock,
          execute: vi.fn(async () => ({ count: 0 })),
        },
      }));

      const { findKeysWithStatisticsBatch } = await import("@/repository/key");
      await findKeysWithStatisticsBatch([1]);

      // Presumably the modelStats select: the field map that carries the token columns.
      const isTokenSelection = (candidate: unknown): candidate is Record<string, unknown> =>
        typeof candidate === "object" &&
        candidate !== null &&
        "inputTokens" in candidate &&
        "cacheReadTokens" in candidate;
      const modelStatsSelection = capturedSelections.find(isTokenSelection);
      expect(modelStatsSelection).toBeTruthy();

      const tokenFields = ["inputTokens", "outputTokens", "cacheCreationTokens", "cacheReadTokens"];
      tokenFields.forEach((field) => expectWideTokenCast(modelStatsSelection?.[field]));
    });
  });