// model-prices.test.ts — unit tests for the server actions in "@/actions/model-prices".

  1. import { beforeEach, describe, expect, it, vi } from "vitest";
  2. import type { ModelPrice, ModelPriceData } from "@/types/model-price";
  3. // Mock dependencies
  4. const getSessionMock = vi.fn();
  5. const revalidatePathMock = vi.fn();
  6. // Repository mocks
  7. const findLatestPriceByModelMock = vi.fn();
  8. const findAllLatestPricesMock = vi.fn();
  9. const createModelPriceMock = vi.fn();
  10. const upsertModelPriceMock = vi.fn();
  11. const deleteModelPriceByNameMock = vi.fn();
  12. const findAllManualPricesMock = vi.fn();
  13. // Price sync mock
  14. const fetchCloudPriceTableTomlMock = vi.fn();
  15. vi.mock("@/lib/auth", () => ({
  16. getSession: () => getSessionMock(),
  17. }));
  18. vi.mock("next/cache", () => ({
  19. revalidatePath: () => revalidatePathMock(),
  20. }));
  21. vi.mock("@/lib/logger", () => ({
  22. logger: {
  23. trace: vi.fn(),
  24. debug: vi.fn(),
  25. info: vi.fn(),
  26. warn: vi.fn(),
  27. error: vi.fn(),
  28. },
  29. }));
  30. vi.mock("@/repository/model-price", () => ({
  31. findLatestPriceByModel: () => findLatestPriceByModelMock(),
  32. createModelPrice: (...args: unknown[]) => createModelPriceMock(...args),
  33. upsertModelPrice: (...args: unknown[]) => upsertModelPriceMock(...args),
  34. deleteModelPriceByName: (...args: unknown[]) => deleteModelPriceByNameMock(...args),
  35. findAllManualPrices: () => findAllManualPricesMock(),
  36. findAllLatestPrices: () => findAllLatestPricesMock(),
  37. findAllLatestPricesPaginated: vi.fn(async () => ({
  38. data: [],
  39. total: 0,
  40. page: 1,
  41. pageSize: 50,
  42. totalPages: 0,
  43. })),
  44. hasAnyPriceRecords: vi.fn(async () => false),
  45. }));
  46. vi.mock("@/lib/price-sync/cloud-price-table", async (importOriginal) => {
  47. const actual = await importOriginal<typeof import("@/lib/price-sync/cloud-price-table")>();
  48. return {
  49. ...actual,
  50. fetchCloudPriceTableToml: (...args: unknown[]) => fetchCloudPriceTableTomlMock(...args),
  51. };
  52. });
  53. // Helper to create mock ModelPrice
  54. function makeMockPrice(
  55. modelName: string,
  56. priceData: Partial<ModelPriceData>,
  57. source: "litellm" | "manual" = "manual"
  58. ): ModelPrice {
  59. const now = new Date();
  60. return {
  61. id: Math.floor(Math.random() * 1000),
  62. modelName,
  63. priceData: {
  64. mode: "chat",
  65. input_cost_per_token: 0.000001,
  66. output_cost_per_token: 0.000002,
  67. ...priceData,
  68. },
  69. source,
  70. createdAt: now,
  71. updatedAt: now,
  72. };
  73. }
  74. describe("Model Price Actions", () => {
  75. beforeEach(() => {
  76. vi.clearAllMocks();
  77. // Default: admin session
  78. getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } });
  79. findAllLatestPricesMock.mockResolvedValue([]);
  80. });
  81. describe("upsertSingleModelPrice", () => {
  82. it("should create a new model price for admin", async () => {
  83. const mockResult = makeMockPrice("gpt-5.2-codex", {
  84. mode: "chat",
  85. input_cost_per_token: 0.000015,
  86. output_cost_per_token: 0.00006,
  87. });
  88. upsertModelPriceMock.mockResolvedValue(mockResult);
  89. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  90. const result = await upsertSingleModelPrice({
  91. modelName: "gpt-5.2-codex",
  92. mode: "chat",
  93. litellmProvider: "openai",
  94. inputCostPerToken: 0.000015,
  95. outputCostPerToken: 0.00006,
  96. });
  97. expect(result.ok).toBe(true);
  98. expect(result.data?.modelName).toBe("gpt-5.2-codex");
  99. expect(upsertModelPriceMock).toHaveBeenCalledWith(
  100. "gpt-5.2-codex",
  101. expect.objectContaining({
  102. mode: "chat",
  103. litellm_provider: "openai",
  104. input_cost_per_token: 0.000015,
  105. output_cost_per_token: 0.00006,
  106. })
  107. );
  108. });
  109. it("should reject empty model name", async () => {
  110. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  111. const result = await upsertSingleModelPrice({
  112. modelName: " ",
  113. mode: "chat",
  114. });
  115. expect(result.ok).toBe(false);
  116. expect(result.error).toContain("模型名称");
  117. expect(upsertModelPriceMock).not.toHaveBeenCalled();
  118. });
  119. it("should reject non-admin users", async () => {
  120. getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
  121. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  122. const result = await upsertSingleModelPrice({
  123. modelName: "test-model",
  124. mode: "chat",
  125. });
  126. expect(result.ok).toBe(false);
  127. expect(result.error).toContain("无权限");
  128. expect(upsertModelPriceMock).not.toHaveBeenCalled();
  129. });
  130. it("should handle image generation mode", async () => {
  131. const mockResult = makeMockPrice("dall-e-3", {
  132. mode: "image_generation",
  133. output_cost_per_image: 0.04,
  134. });
  135. upsertModelPriceMock.mockResolvedValue(mockResult);
  136. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  137. const result = await upsertSingleModelPrice({
  138. modelName: "dall-e-3",
  139. mode: "image_generation",
  140. litellmProvider: "openai",
  141. outputCostPerImage: 0.04,
  142. });
  143. expect(result.ok).toBe(true);
  144. expect(upsertModelPriceMock).toHaveBeenCalledWith(
  145. "dall-e-3",
  146. expect.objectContaining({
  147. mode: "image_generation",
  148. output_cost_per_image: 0.04,
  149. })
  150. );
  151. });
  152. it("should handle repository errors gracefully", async () => {
  153. upsertModelPriceMock.mockRejectedValue(new Error("Database error"));
  154. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  155. const result = await upsertSingleModelPrice({
  156. modelName: "test-model",
  157. mode: "chat",
  158. });
  159. expect(result.ok).toBe(false);
  160. expect(result.error).toBeDefined();
  161. });
  162. });
  163. describe("deleteSingleModelPrice", () => {
  164. it("should delete a model price for admin", async () => {
  165. deleteModelPriceByNameMock.mockResolvedValue(undefined);
  166. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  167. const result = await deleteSingleModelPrice("gpt-5.2-codex");
  168. expect(result.ok).toBe(true);
  169. expect(deleteModelPriceByNameMock).toHaveBeenCalledWith("gpt-5.2-codex");
  170. });
  171. it("should reject empty model name", async () => {
  172. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  173. const result = await deleteSingleModelPrice("");
  174. expect(result.ok).toBe(false);
  175. expect(result.error).toContain("模型名称");
  176. expect(deleteModelPriceByNameMock).not.toHaveBeenCalled();
  177. });
  178. it("should reject non-admin users", async () => {
  179. getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
  180. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  181. const result = await deleteSingleModelPrice("test-model");
  182. expect(result.ok).toBe(false);
  183. expect(result.error).toContain("无权限");
  184. expect(deleteModelPriceByNameMock).not.toHaveBeenCalled();
  185. });
  186. it("should handle repository errors gracefully", async () => {
  187. deleteModelPriceByNameMock.mockRejectedValue(new Error("Database error"));
  188. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  189. const result = await deleteSingleModelPrice("test-model");
  190. expect(result.ok).toBe(false);
  191. expect(result.error).toBeDefined();
  192. });
  193. });
  194. describe("checkLiteLLMSyncConflicts", () => {
  195. it("should return no conflicts when no manual prices exist", async () => {
  196. findAllManualPricesMock.mockResolvedValue(new Map());
  197. fetchCloudPriceTableTomlMock.mockResolvedValue({
  198. ok: true,
  199. data: ['[models."claude-3-opus"]', 'mode = "chat"', "input_cost_per_token = 0.000015"].join(
  200. "\n"
  201. ),
  202. });
  203. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  204. const result = await checkLiteLLMSyncConflicts();
  205. expect(result.ok).toBe(true);
  206. expect(result.data?.hasConflicts).toBe(false);
  207. expect(result.data?.conflicts).toHaveLength(0);
  208. });
  209. it("should detect conflicts when manual prices exist in LiteLLM", async () => {
  210. const manualPrice = makeMockPrice("claude-3-opus", {
  211. mode: "chat",
  212. input_cost_per_token: 0.00001,
  213. output_cost_per_token: 0.00002,
  214. });
  215. findAllManualPricesMock.mockResolvedValue(new Map([["claude-3-opus", manualPrice]]));
  216. fetchCloudPriceTableTomlMock.mockResolvedValue({
  217. ok: true,
  218. data: [
  219. '[models."claude-3-opus"]',
  220. 'mode = "chat"',
  221. "input_cost_per_token = 0.000015",
  222. "output_cost_per_token = 0.00006",
  223. ].join("\n"),
  224. });
  225. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  226. const result = await checkLiteLLMSyncConflicts();
  227. expect(result.ok).toBe(true);
  228. expect(result.data?.hasConflicts).toBe(true);
  229. expect(result.data?.conflicts).toHaveLength(1);
  230. expect(result.data?.conflicts[0]?.modelName).toBe("claude-3-opus");
  231. });
  232. it("should not report conflicts for manual prices not in LiteLLM", async () => {
  233. const manualPrice = makeMockPrice("custom-model", {
  234. mode: "chat",
  235. input_cost_per_token: 0.00001,
  236. });
  237. findAllManualPricesMock.mockResolvedValue(new Map([["custom-model", manualPrice]]));
  238. fetchCloudPriceTableTomlMock.mockResolvedValue({
  239. ok: true,
  240. data: ['[models."claude-3-opus"]', 'mode = "chat"', "input_cost_per_token = 0.000015"].join(
  241. "\n"
  242. ),
  243. });
  244. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  245. const result = await checkLiteLLMSyncConflicts();
  246. expect(result.ok).toBe(true);
  247. expect(result.data?.hasConflicts).toBe(false);
  248. expect(result.data?.conflicts).toHaveLength(0);
  249. });
  250. it("should reject non-admin users", async () => {
  251. getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
  252. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  253. const result = await checkLiteLLMSyncConflicts();
  254. expect(result.ok).toBe(false);
  255. expect(result.error).toContain("无权限");
  256. });
  257. it("should handle network errors gracefully", async () => {
  258. findAllManualPricesMock.mockResolvedValue(new Map());
  259. fetchCloudPriceTableTomlMock.mockResolvedValue({
  260. ok: false,
  261. error: "云端价格表拉取失败:mock",
  262. });
  263. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  264. const result = await checkLiteLLMSyncConflicts();
  265. expect(result.ok).toBe(false);
  266. expect(result.error).toContain("云端");
  267. });
  268. it("should handle invalid TOML gracefully", async () => {
  269. findAllManualPricesMock.mockResolvedValue(new Map());
  270. fetchCloudPriceTableTomlMock.mockResolvedValue({
  271. ok: true,
  272. data: "[models\ninvalid = true",
  273. });
  274. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  275. const result = await checkLiteLLMSyncConflicts();
  276. expect(result.ok).toBe(false);
  277. expect(result.error).toContain("TOML");
  278. });
  279. });
  280. describe("processPriceTableInternal - source handling", () => {
  281. it("should skip manual prices during sync by default", async () => {
  282. const manualPrice = makeMockPrice("custom-model", {
  283. mode: "chat",
  284. input_cost_per_token: 0.00001,
  285. });
  286. findAllManualPricesMock.mockResolvedValue(new Map([["custom-model", manualPrice]]));
  287. findAllLatestPricesMock.mockResolvedValue([manualPrice]);
  288. const { processPriceTableInternal } = await import("@/actions/model-prices");
  289. const result = await processPriceTableInternal(
  290. JSON.stringify({
  291. "custom-model": {
  292. mode: "chat",
  293. input_cost_per_token: 0.000015,
  294. },
  295. })
  296. );
  297. expect(result.ok).toBe(true);
  298. expect(result.data?.skippedConflicts).toContain("custom-model");
  299. expect(result.data?.unchanged).toContain("custom-model");
  300. expect(createModelPriceMock).not.toHaveBeenCalled();
  301. });
  302. it("should overwrite manual prices when specified", async () => {
  303. const manualPrice = makeMockPrice("custom-model", {
  304. mode: "chat",
  305. input_cost_per_token: 0.00001,
  306. });
  307. findAllManualPricesMock.mockResolvedValue(new Map([["custom-model", manualPrice]]));
  308. findAllLatestPricesMock.mockResolvedValue([manualPrice]);
  309. deleteModelPriceByNameMock.mockResolvedValue(undefined);
  310. createModelPriceMock.mockResolvedValue(
  311. makeMockPrice(
  312. "custom-model",
  313. {
  314. mode: "chat",
  315. input_cost_per_token: 0.000015,
  316. },
  317. "litellm"
  318. )
  319. );
  320. const { processPriceTableInternal } = await import("@/actions/model-prices");
  321. const result = await processPriceTableInternal(
  322. JSON.stringify({
  323. "custom-model": {
  324. mode: "chat",
  325. input_cost_per_token: 0.000015,
  326. },
  327. }),
  328. ["custom-model"] // Overwrite list
  329. );
  330. expect(result.ok).toBe(true);
  331. expect(result.data?.updated).toContain("custom-model");
  332. expect(deleteModelPriceByNameMock).toHaveBeenCalledWith("custom-model");
  333. expect(createModelPriceMock).toHaveBeenCalled();
  334. });
  335. it("should add new models with litellm source", async () => {
  336. findAllManualPricesMock.mockResolvedValue(new Map());
  337. findAllLatestPricesMock.mockResolvedValue([]);
  338. createModelPriceMock.mockResolvedValue(
  339. makeMockPrice(
  340. "new-model",
  341. {
  342. mode: "chat",
  343. },
  344. "litellm"
  345. )
  346. );
  347. const { processPriceTableInternal } = await import("@/actions/model-prices");
  348. const result = await processPriceTableInternal(
  349. JSON.stringify({
  350. "new-model": {
  351. mode: "chat",
  352. input_cost_per_token: 0.000001,
  353. },
  354. })
  355. );
  356. expect(result.ok).toBe(true);
  357. expect(result.data?.added).toContain("new-model");
  358. expect(createModelPriceMock).toHaveBeenCalledWith("new-model", expect.any(Object), "litellm");
  359. });
  360. it("should skip metadata fields like sample_spec", async () => {
  361. findAllManualPricesMock.mockResolvedValue(new Map());
  362. findAllLatestPricesMock.mockResolvedValue([]);
  363. const { processPriceTableInternal } = await import("@/actions/model-prices");
  364. const result = await processPriceTableInternal(
  365. JSON.stringify({
  366. sample_spec: { description: "This is metadata" },
  367. "real-model": { mode: "chat", input_cost_per_token: 0.000001 },
  368. })
  369. );
  370. expect(result.ok).toBe(true);
  371. expect(result.data?.total).toBe(1); // Only real-model
  372. expect(result.data?.failed).not.toContain("sample_spec");
  373. });
  374. it("should skip entries without mode field", async () => {
  375. findAllManualPricesMock.mockResolvedValue(new Map());
  376. findAllLatestPricesMock.mockResolvedValue([]);
  377. const { processPriceTableInternal } = await import("@/actions/model-prices");
  378. const result = await processPriceTableInternal(
  379. JSON.stringify({
  380. "invalid-model": { input_cost_per_token: 0.000001 }, // No mode
  381. "valid-model": { mode: "chat", input_cost_per_token: 0.000001 },
  382. })
  383. );
  384. expect(result.ok).toBe(true);
  385. expect(result.data?.failed).toContain("invalid-model");
  386. });
  387. it("should ignore dangerous keys when comparing price data", async () => {
  388. const existing = makeMockPrice(
  389. "safe-model",
  390. {
  391. mode: "chat",
  392. input_cost_per_token: 0.000001,
  393. output_cost_per_token: 0.000002,
  394. },
  395. "litellm"
  396. );
  397. findAllManualPricesMock.mockResolvedValue(new Map());
  398. findAllLatestPricesMock.mockResolvedValue([existing]);
  399. const { processPriceTableInternal } = await import("@/actions/model-prices");
  400. const result = await processPriceTableInternal(
  401. JSON.stringify({
  402. "safe-model": {
  403. mode: "chat",
  404. input_cost_per_token: 0.000001,
  405. output_cost_per_token: 0.000002,
  406. constructor: { prototype: { polluted: true } },
  407. },
  408. })
  409. );
  410. expect(result.ok).toBe(true);
  411. expect(result.data?.unchanged).toContain("safe-model");
  412. expect(createModelPriceMock).not.toHaveBeenCalled();
  413. });
  414. });
  415. });