message-origin-chain.test.ts 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. import { describe, expect, test, vi } from "vitest";
  2. import type { ProviderChainItem } from "@/types/message";
  /**
   * Best-effort flattening of a drizzle-style SQL fragment into plain text so
   * tests can assert on pieces of the generated query (column names, operators,
   * bound values).
   *
   * Rendering rules, applied recursively:
   * - strings render as-is; numbers and bigints via `String`;
   * - arrays render as the concatenation of their elements;
   * - objects with a non-empty string `name` (e.g. columns) render as that name;
   * - objects with a non-null `value` (e.g. string chunks / bound params) render
   *   that value — arrays are concatenated element-wise, and non-string values
   *   are JSON-encoded when possible so structured params stay searchable;
   * - objects with `queryChunks` (nested SQL) render their chunks;
   * - anything else (null, undefined, booleans, functions, unknown shapes)
   *   renders as "".
   *
   * Cycle protection only tracks the objects on the current recursion path, so a
   * node that legitimately appears more than once — the same column referenced
   * twice, or a repeated separator string — is rendered every time.
   */
  function sqlToString(sqlObj: unknown): string {
    // Objects currently being rendered; re-entering one of them means a cycle.
    const inProgress = new Set<object>();

    // Render a leaf held in a chunk's `value` field.
    const renderValue = (value: unknown): string => {
      if (typeof value === "string") return value;
      if (typeof value === "object" && value !== null) {
        try {
          return JSON.stringify(value) ?? String(value);
        } catch {
          // Not JSON-serialisable (e.g. BigInt members or circular data).
          return String(value);
        }
      }
      return String(value);
    };

    const walk = (node: unknown): string => {
      if (typeof node === "string") return node;
      if (typeof node === "number" || typeof node === "bigint") return String(node);
      if (typeof node !== "object" || node === null || inProgress.has(node)) return "";

      inProgress.add(node);
      try {
        if (Array.isArray(node)) {
          return node.map((child) => walk(child)).join("");
        }
        const record = node as Record<string, unknown>;
        if (typeof record.name === "string" && record.name !== "") {
          return record.name;
        }
        const { value } = record;
        if (value !== undefined && value !== null) {
          return Array.isArray(value)
            ? value.map((part) => renderValue(part)).join("")
            : renderValue(value);
        }
        if (record.queryChunks !== undefined) {
          return walk(record.queryChunks);
        }
        return "";
      } finally {
        // Leaving this subtree: the node may be rendered again elsewhere.
        inProgress.delete(node);
      }
    };

    return walk(sqlObj);
  }
  /**
   * Creates a fake drizzle query builder for `db.select()` mocks.
   *
   * The returned object is a real Promise resolving to `result`, decorated with
   * chainable `from` / `where` / `orderBy` / `limit` spies that each return the
   * same object, so awaiting any chain of those calls yields `result`.
   * When `opts` supplies capture arrays, each call appends to them:
   * - `where` and `limit` append their first argument;
   * - `orderBy` appends its whole argument list as a single array entry.
   */
  function createThenableQuery<T>(
    result: T,
    opts?: {
      whereArgs?: unknown[];
      orderByArgs?: unknown[];
      limitArgs?: unknown[];
    }
  ) {
    const builder: any = Promise.resolve(result);

    // Record `entry` into an optional capture list, then keep the chain going.
    const capture = (sink: unknown[] | undefined, entry: unknown) => {
      sink?.push(entry);
      return builder;
    };

    builder.from = vi.fn(() => builder);
    builder.where = vi.fn((arg: unknown) => capture(opts?.whereArgs, arg));
    builder.orderBy = vi.fn((...args: unknown[]) => capture(opts?.orderByArgs, args));
    builder.limit = vi.fn((arg: unknown) => capture(opts?.limitArgs, arg));
    return builder;
  }
  describe("repository/message findSessionOriginChain", () => {
    // NOTE: `@/drizzle/db` is fully mocked, so no real rows are filtered here.
    // Skipping warmup requests and keeping only `initial_selection` chains happen
    // in the SQL the repository builds. These tests therefore check:
    // (a) the first row the mocked query yields is returned, or null when there
    //     is none;
    // (b) the captured WHERE / ORDER BY / LIMIT arguments carry those filters.

    /** One-element provider chain fixture for an initial weighted-random selection. */
    const makeChain = (id: number, name: string, attemptNumber: number): ProviderChainItem[] => [
      { id, name, reason: "initial_selection", selectionMethod: "weighted_random", attemptNumber },
    ];

    /**
     * Resets the module registry, mocks `@/drizzle/db` so `db.select()` resolves to
     * `rows`, and imports a fresh `findSessionOriginChain` bound to that mock.
     * Arguments passed to `where` / `orderBy` / `limit` are captured in the
     * returned arrays.
     */
    const loadWithRows = async (rows: unknown[]) => {
      vi.resetModules();
      const whereArgs: unknown[] = [];
      const orderByArgs: unknown[] = [];
      const limitArgs: unknown[] = [];
      const selectMock = vi.fn(() => createThenableQuery(rows, { whereArgs, orderByArgs, limitArgs }));
      vi.doMock("@/drizzle/db", () => ({
        db: {
          select: selectMock,
          execute: vi.fn(async () => ({ count: 0 })),
        },
      }));
      const { findSessionOriginChain } = await import("@/repository/message");
      return { findSessionOriginChain, whereArgs, orderByArgs, limitArgs };
    };

    /** Lower-cased text of the first WHERE argument; fails if `where` was never called. */
    const firstWhereSql = (whereArgs: unknown[]): string => {
      expect(whereArgs.length).toBeGreaterThan(0);
      return sqlToString(whereArgs[0]).toLowerCase();
    };

    test("happy path: 返回 session 首条非 warmup 的完整 providerChain", async () => {
      const chain = makeChain(101, "provider-a", 1);
      const { findSessionOriginChain, whereArgs, orderByArgs, limitArgs } = await loadWithRows([
        { providerChain: chain },
      ]);

      const result = await findSessionOriginChain("session-happy");

      expect(result).toEqual(chain);
      const whereSql = firstWhereSql(whereArgs);
      expect(whereSql).toContain("warmup");
      expect(whereSql).toContain("is not null");
      expect(whereSql).toContain("initial_selection");
      // The earliest request in the session wins, and only that single row is fetched.
      expect(orderByArgs.length).toBeGreaterThan(0);
      const orderSql = sqlToString(orderByArgs[0]).toLowerCase();
      expect(orderSql).toContain("request_sequence");
      expect(orderSql).toContain("asc");
      expect(limitArgs).toEqual([1]);
    });

    test("warmup skip: 第一条为 warmup 时应返回后续首条非 warmup 的 chain", async () => {
      const chain = makeChain(202, "provider-b", 2);
      const { findSessionOriginChain, whereArgs } = await loadWithRows([{ providerChain: chain }]);

      const result = await findSessionOriginChain("session-warmup-first");

      expect(result).toEqual(chain);
      // Warmup rows are skipped in SQL, so the exclusion must be part of the WHERE clause.
      expect(firstWhereSql(whereArgs)).toContain("warmup");
    });

    test("no data: session 不存在时返回 null", async () => {
      const { findSessionOriginChain } = await loadWithRows([]);

      const result = await findSessionOriginChain("session-not-found");

      expect(result).toBeNull();
    });

    test("all warmup: 全部请求都被 warmup 拦截时返回 null", async () => {
      const { findSessionOriginChain, whereArgs } = await loadWithRows([]);

      const result = await findSessionOriginChain("session-all-warmup");

      expect(result).toBeNull();
      // An empty result here is only meaningful if warmup rows were filtered by the query.
      expect(firstWhereSql(whereArgs)).toContain("warmup");
    });

    test("null providerChain: 首条非 warmup 记录 providerChain 为空时返回 null", async () => {
      const { findSessionOriginChain } = await loadWithRows([{ providerChain: null }]);

      const result = await findSessionOriginChain("session-null-provider-chain");

      expect(result).toBeNull();
    });

    test("all session_reuse: 全部请求都是 session_reuse 时 JSONB 过滤后返回 null", async () => {
      const { findSessionOriginChain, whereArgs } = await loadWithRows([]);

      const result = await findSessionOriginChain("session-all-reuse");

      expect(result).toBeNull();
      // session_reuse-only chains are dropped by the JSONB containment filter in SQL.
      const whereSql = firstWhereSql(whereArgs);
      expect(whereSql).toContain("initial_selection");
      expect(whereSql).toContain("@>");
    });

    test("JSONB filter present: WHERE 子句包含 initial_selection 过滤条件", async () => {
      const chain = makeChain(301, "provider-c", 1);
      const { findSessionOriginChain, whereArgs } = await loadWithRows([{ providerChain: chain }]);

      await findSessionOriginChain("session-jsonb-filter");

      const whereSql = firstWhereSql(whereArgs);
      expect(whereSql).toContain("initial_selection");
      expect(whereSql).toContain("@>");
    });
  });