// key-usage-token-overflow.test.ts
  1. import { describe, expect, test, vi } from "vitest";
  2. // 禁用 tests/setup.ts 中基于 DSN/Redis 的默认同步与清理协调,避免无关依赖引入。
  3. process.env.DSN = "";
  4. process.env.AUTO_CLEANUP_TEST_DATA = "false";
  5. function sqlToString(sqlObj: unknown): string {
  6. const visited = new Set<unknown>();
  7. const walk = (node: unknown): string => {
  8. if (!node || visited.has(node)) return "";
  9. visited.add(node);
  10. if (typeof node === "string") return node;
  11. if (typeof node === "object") {
  12. const anyNode = node as any;
  13. if (Array.isArray(anyNode)) {
  14. return anyNode.map(walk).join("");
  15. }
  16. if (anyNode.value) {
  17. if (Array.isArray(anyNode.value)) {
  18. return anyNode.value.map(String).join("");
  19. }
  20. return String(anyNode.value);
  21. }
  22. if (anyNode.queryChunks) {
  23. return walk(anyNode.queryChunks);
  24. }
  25. }
  26. return "";
  27. };
  28. return walk(sqlObj);
  29. }
  30. function createThenableQuery<T>(result: T) {
  31. const query: any = Promise.resolve(result);
  32. query.from = vi.fn(() => query);
  33. query.leftJoin = vi.fn(() => query);
  34. query.innerJoin = vi.fn(() => query);
  35. query.where = vi.fn(() => query);
  36. query.groupBy = vi.fn(() => query);
  37. query.orderBy = vi.fn(() => query);
  38. query.limit = vi.fn(() => query);
  39. query.offset = vi.fn(() => query);
  40. return query;
  41. }
  42. describe("Key usage token aggregation overflow", () => {
  43. test("findKeyUsageTodayBatch: token sum 不应使用 ::int", async () => {
  44. vi.resetModules();
  45. const selectArgs: unknown[] = [];
  46. const selectMock = vi.fn((selection: unknown) => {
  47. selectArgs.push(selection);
  48. return createThenableQuery([]);
  49. });
  50. vi.doMock("@/drizzle/db", () => ({
  51. db: {
  52. select: selectMock,
  53. // 给 tests/setup.ts 的 afterAll 清理逻辑一个可用的 execute
  54. execute: vi.fn(async () => ({ count: 0 })),
  55. },
  56. }));
  57. const { findKeyUsageTodayBatch } = await import("@/repository/key");
  58. await findKeyUsageTodayBatch([1]);
  59. expect(selectArgs).toHaveLength(1);
  60. const selection = selectArgs[0] as Record<string, unknown>;
  61. const totalTokensSql = sqlToString(selection.totalTokens).toLowerCase();
  62. expect(totalTokensSql).not.toContain("::int");
  63. expect(totalTokensSql).not.toContain("::int4");
  64. expect(totalTokensSql).toContain("double precision");
  65. });
  66. test("findKeysWithStatisticsBatch: unnest 必须使用 ARRAY[] 而非行构造器", async () => {
  67. vi.resetModules();
  68. const executeSqlArgs: unknown[] = [];
  69. const selectQueue: any[] = [];
  70. selectQueue.push(
  71. createThenableQuery([
  72. {
  73. id: 10,
  74. userId: 1,
  75. key: "k",
  76. name: "n",
  77. isEnabled: true,
  78. expiresAt: null,
  79. canLoginWebUi: true,
  80. limit5hUsd: null,
  81. limitDailyUsd: null,
  82. dailyResetMode: "fixed",
  83. dailyResetTime: "00:00",
  84. limitWeeklyUsd: null,
  85. limitMonthlyUsd: null,
  86. limitTotalUsd: null,
  87. limitConcurrentSessions: 0,
  88. providerGroup: null,
  89. cacheTtlPreference: null,
  90. createdAt: new Date("2024-01-01T00:00:00.000Z"),
  91. updatedAt: new Date("2024-01-01T00:00:00.000Z"),
  92. deletedAt: null,
  93. },
  94. ])
  95. );
  96. selectQueue.push(createThenableQuery([]));
  97. const fallbackSelect = createThenableQuery<unknown[]>([]);
  98. const selectMock = vi.fn((_selection: unknown) => selectQueue.shift() ?? fallbackSelect);
  99. vi.doMock("@/drizzle/db", () => ({
  100. db: {
  101. select: selectMock,
  102. execute: vi.fn(async (sqlObj: unknown) => {
  103. executeSqlArgs.push(sqlObj);
  104. return [];
  105. }),
  106. },
  107. }));
  108. const { findKeysWithStatisticsBatch } = await import("@/repository/key");
  109. await findKeysWithStatisticsBatch([1]);
  110. expect(executeSqlArgs.length).toBeGreaterThan(0);
  111. const lateralJoinSql = sqlToString(executeSqlArgs[0]).toLowerCase();
  112. expect(lateralJoinSql).toContain("array[");
  113. expect(lateralJoinSql).not.toContain("unnest((");
  114. expect(lateralJoinSql).toContain("key_val");
  115. });
  116. test("findKeysWithStatisticsBatch: last usage 排序不能附带 NULLS LAST", async () => {
  117. vi.resetModules();
  118. const executeSqlArgs: unknown[] = [];
  119. const selectQueue: any[] = [];
  120. selectQueue.push(
  121. createThenableQuery([
  122. {
  123. id: 10,
  124. userId: 1,
  125. key: "k",
  126. name: "n",
  127. isEnabled: true,
  128. expiresAt: null,
  129. canLoginWebUi: true,
  130. limit5hUsd: null,
  131. limitDailyUsd: null,
  132. dailyResetMode: "fixed",
  133. dailyResetTime: "00:00",
  134. limitWeeklyUsd: null,
  135. limitMonthlyUsd: null,
  136. limitTotalUsd: null,
  137. limitConcurrentSessions: 0,
  138. providerGroup: null,
  139. cacheTtlPreference: null,
  140. createdAt: new Date("2024-01-01T00:00:00.000Z"),
  141. updatedAt: new Date("2024-01-01T00:00:00.000Z"),
  142. deletedAt: null,
  143. },
  144. ])
  145. );
  146. selectQueue.push(createThenableQuery([]));
  147. const fallbackSelect = createThenableQuery<unknown[]>([]);
  148. const selectMock = vi.fn((_selection: unknown) => selectQueue.shift() ?? fallbackSelect);
  149. vi.doMock("@/drizzle/db", () => ({
  150. db: {
  151. select: selectMock,
  152. execute: vi.fn(async (sqlObj: unknown) => {
  153. executeSqlArgs.push(sqlObj);
  154. return [];
  155. }),
  156. },
  157. }));
  158. const { findKeysWithStatisticsBatch } = await import("@/repository/key");
  159. await findKeysWithStatisticsBatch([1]);
  160. expect(executeSqlArgs.length).toBeGreaterThan(0);
  161. const lastUsageSql = sqlToString(executeSqlArgs[0]).toLowerCase();
  162. expect(lastUsageSql).toContain("order by ul.created_at desc");
  163. expect(lastUsageSql).not.toContain("nulls last");
  164. });
  165. test("findKeysWithStatisticsBatch: modelStats token sum 不应使用 ::int", async () => {
  166. vi.resetModules();
  167. const selectArgs: unknown[] = [];
  168. const selectQueue: any[] = [];
  169. selectQueue.push(
  170. createThenableQuery([
  171. {
  172. id: 10,
  173. userId: 1,
  174. key: "k",
  175. name: "n",
  176. isEnabled: true,
  177. expiresAt: null,
  178. canLoginWebUi: true,
  179. limit5hUsd: null,
  180. limitDailyUsd: null,
  181. dailyResetMode: "fixed",
  182. dailyResetTime: "00:00",
  183. limitWeeklyUsd: null,
  184. limitMonthlyUsd: null,
  185. limitTotalUsd: null,
  186. limitConcurrentSessions: 0,
  187. providerGroup: null,
  188. cacheTtlPreference: null,
  189. createdAt: new Date("2024-01-01T00:00:00.000Z"),
  190. updatedAt: new Date("2024-01-01T00:00:00.000Z"),
  191. deletedAt: null,
  192. },
  193. ])
  194. );
  195. selectQueue.push(createThenableQuery([]));
  196. const fallbackSelect = createThenableQuery<unknown[]>([]);
  197. const selectMock = vi.fn((selection: unknown) => {
  198. selectArgs.push(selection);
  199. return selectQueue.shift() ?? fallbackSelect;
  200. });
  201. vi.doMock("@/drizzle/db", () => ({
  202. db: {
  203. select: selectMock,
  204. execute: vi.fn(async () => []),
  205. },
  206. }));
  207. const { findKeysWithStatisticsBatch } = await import("@/repository/key");
  208. await findKeysWithStatisticsBatch([1]);
  209. const selection = selectArgs.find((s): s is Record<string, unknown> => {
  210. if (!s || typeof s !== "object") return false;
  211. return "inputTokens" in s && "cacheReadTokens" in s;
  212. });
  213. expect(selection).toBeTruthy();
  214. for (const field of ["inputTokens", "outputTokens", "cacheCreationTokens", "cacheReadTokens"]) {
  215. const tokenSql = sqlToString(selection?.[field]).toLowerCase();
  216. expect(tokenSql).not.toContain("::int");
  217. expect(tokenSql).not.toContain("::int4");
  218. expect(tokenSql).toContain("double precision");
  219. }
  220. });
  221. });