model-prices.test.ts 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. import { beforeEach, describe, expect, it, vi } from "vitest";
  2. import type { ModelPrice, ModelPriceData } from "@/types/model-price";
  3. // Mock dependencies
  4. const getSessionMock = vi.fn();
  5. const revalidatePathMock = vi.fn();
  6. // Repository mocks
  7. const findLatestPriceByModelMock = vi.fn();
  8. const findLatestPriceByModelAndSourceMock = vi.fn();
  9. const findAllLatestPricesMock = vi.fn();
  10. const createModelPriceMock = vi.fn();
  11. const upsertModelPriceMock = vi.fn();
  12. const deleteModelPriceByNameMock = vi.fn();
  13. const findAllManualPricesMock = vi.fn();
  14. // Price sync mock
  15. const fetchCloudPriceTableTomlMock = vi.fn();
  16. vi.mock("@/lib/auth", () => ({
  17. getSession: () => getSessionMock(),
  18. }));
  19. vi.mock("next/cache", () => ({
  20. revalidatePath: () => revalidatePathMock(),
  21. }));
  22. vi.mock("@/lib/logger", () => ({
  23. logger: {
  24. trace: vi.fn(),
  25. debug: vi.fn(),
  26. info: vi.fn(),
  27. warn: vi.fn(),
  28. error: vi.fn(),
  29. },
  30. }));
  31. vi.mock("@/repository/model-price", () => ({
  32. findLatestPriceByModel: () => findLatestPriceByModelMock(),
  33. findLatestPriceByModelAndSource: (...args: unknown[]) =>
  34. findLatestPriceByModelAndSourceMock(...args),
  35. createModelPrice: (...args: unknown[]) => createModelPriceMock(...args),
  36. upsertModelPrice: (...args: unknown[]) => upsertModelPriceMock(...args),
  37. deleteModelPriceByName: (...args: unknown[]) => deleteModelPriceByNameMock(...args),
  38. findAllManualPrices: () => findAllManualPricesMock(),
  39. findAllLatestPrices: () => findAllLatestPricesMock(),
  40. findAllLatestPricesPaginated: vi.fn(async () => ({
  41. data: [],
  42. total: 0,
  43. page: 1,
  44. pageSize: 50,
  45. totalPages: 0,
  46. })),
  47. hasAnyPriceRecords: vi.fn(async () => false),
  48. }));
  49. vi.mock("@/lib/price-sync/cloud-price-table", async (importOriginal) => {
  50. const actual = await importOriginal<typeof import("@/lib/price-sync/cloud-price-table")>();
  51. return {
  52. ...actual,
  53. fetchCloudPriceTableToml: (...args: unknown[]) => fetchCloudPriceTableTomlMock(...args),
  54. };
  55. });
  56. // Helper to create mock ModelPrice
  57. function makeMockPrice(
  58. modelName: string,
  59. priceData: Partial<ModelPriceData>,
  60. source: "litellm" | "manual" = "manual"
  61. ): ModelPrice {
  62. const now = new Date();
  63. return {
  64. id: Math.floor(Math.random() * 1000),
  65. modelName,
  66. priceData: {
  67. mode: "chat",
  68. input_cost_per_token: 0.000001,
  69. output_cost_per_token: 0.000002,
  70. ...priceData,
  71. },
  72. source,
  73. createdAt: now,
  74. updatedAt: now,
  75. };
  76. }
  77. describe("Model Price Actions", () => {
  78. beforeEach(() => {
  79. vi.clearAllMocks();
  80. // Default: admin session
  81. getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } });
  82. findAllLatestPricesMock.mockResolvedValue([]);
  83. });
  84. describe("upsertSingleModelPrice", () => {
  85. it("should create a new model price for admin", async () => {
  86. const mockResult = makeMockPrice("gpt-5.2-codex", {
  87. mode: "chat",
  88. input_cost_per_token: 0.000015,
  89. output_cost_per_token: 0.00006,
  90. });
  91. upsertModelPriceMock.mockResolvedValue(mockResult);
  92. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  93. const result = await upsertSingleModelPrice({
  94. modelName: "gpt-5.2-codex",
  95. mode: "chat",
  96. litellmProvider: "openai",
  97. inputCostPerToken: 0.000015,
  98. outputCostPerToken: 0.00006,
  99. });
  100. expect(result.ok).toBe(true);
  101. expect(result.data?.modelName).toBe("gpt-5.2-codex");
  102. expect(upsertModelPriceMock).toHaveBeenCalledWith(
  103. "gpt-5.2-codex",
  104. expect.objectContaining({
  105. mode: "chat",
  106. litellm_provider: "openai",
  107. input_cost_per_token: 0.000015,
  108. output_cost_per_token: 0.00006,
  109. })
  110. );
  111. });
  112. it("should reject empty model name", async () => {
  113. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  114. const result = await upsertSingleModelPrice({
  115. modelName: " ",
  116. mode: "chat",
  117. });
  118. expect(result.ok).toBe(false);
  119. expect(result.error).toContain("模型名称");
  120. expect(upsertModelPriceMock).not.toHaveBeenCalled();
  121. });
  122. it("should reject non-admin users", async () => {
  123. getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
  124. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  125. const result = await upsertSingleModelPrice({
  126. modelName: "test-model",
  127. mode: "chat",
  128. });
  129. expect(result.ok).toBe(false);
  130. expect(result.error).toContain("无权限");
  131. expect(upsertModelPriceMock).not.toHaveBeenCalled();
  132. });
  133. it("should handle image generation mode", async () => {
  134. const mockResult = makeMockPrice("dall-e-3", {
  135. mode: "image_generation",
  136. output_cost_per_image: 0.04,
  137. });
  138. upsertModelPriceMock.mockResolvedValue(mockResult);
  139. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  140. const result = await upsertSingleModelPrice({
  141. modelName: "dall-e-3",
  142. mode: "image_generation",
  143. litellmProvider: "openai",
  144. outputCostPerImage: 0.04,
  145. });
  146. expect(result.ok).toBe(true);
  147. expect(upsertModelPriceMock).toHaveBeenCalledWith(
  148. "dall-e-3",
  149. expect.objectContaining({
  150. mode: "image_generation",
  151. output_cost_per_image: 0.04,
  152. })
  153. );
  154. });
  155. it("should handle repository errors gracefully", async () => {
  156. upsertModelPriceMock.mockRejectedValue(new Error("Database error"));
  157. const { upsertSingleModelPrice } = await import("@/actions/model-prices");
  158. const result = await upsertSingleModelPrice({
  159. modelName: "test-model",
  160. mode: "chat",
  161. });
  162. expect(result.ok).toBe(false);
  163. expect(result.error).toBeDefined();
  164. });
  165. });
  166. describe("deleteSingleModelPrice", () => {
  167. it("should delete a model price for admin", async () => {
  168. deleteModelPriceByNameMock.mockResolvedValue(undefined);
  169. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  170. const result = await deleteSingleModelPrice("gpt-5.2-codex");
  171. expect(result.ok).toBe(true);
  172. expect(deleteModelPriceByNameMock).toHaveBeenCalledWith("gpt-5.2-codex");
  173. });
  174. it("should reject empty model name", async () => {
  175. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  176. const result = await deleteSingleModelPrice("");
  177. expect(result.ok).toBe(false);
  178. expect(result.error).toContain("模型名称");
  179. expect(deleteModelPriceByNameMock).not.toHaveBeenCalled();
  180. });
  181. it("should reject non-admin users", async () => {
  182. getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
  183. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  184. const result = await deleteSingleModelPrice("test-model");
  185. expect(result.ok).toBe(false);
  186. expect(result.error).toContain("无权限");
  187. expect(deleteModelPriceByNameMock).not.toHaveBeenCalled();
  188. });
  189. it("should handle repository errors gracefully", async () => {
  190. deleteModelPriceByNameMock.mockRejectedValue(new Error("Database error"));
  191. const { deleteSingleModelPrice } = await import("@/actions/model-prices");
  192. const result = await deleteSingleModelPrice("test-model");
  193. expect(result.ok).toBe(false);
  194. expect(result.error).toBeDefined();
  195. });
  196. });
  197. describe("checkLiteLLMSyncConflicts", () => {
  198. it("should return no conflicts when no manual prices exist", async () => {
  199. findAllManualPricesMock.mockResolvedValue(new Map());
  200. fetchCloudPriceTableTomlMock.mockResolvedValue({
  201. ok: true,
  202. data: ['[models."claude-3-opus"]', 'mode = "chat"', "input_cost_per_token = 0.000015"].join(
  203. "\n"
  204. ),
  205. });
  206. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  207. const result = await checkLiteLLMSyncConflicts();
  208. expect(result.ok).toBe(true);
  209. expect(result.data?.hasConflicts).toBe(false);
  210. expect(result.data?.conflicts).toHaveLength(0);
  211. });
  212. it("should detect conflicts when manual prices exist in LiteLLM", async () => {
  213. const manualPrice = makeMockPrice("claude-3-opus", {
  214. mode: "chat",
  215. input_cost_per_token: 0.00001,
  216. output_cost_per_token: 0.00002,
  217. });
  218. findAllManualPricesMock.mockResolvedValue(new Map([["claude-3-opus", manualPrice]]));
  219. fetchCloudPriceTableTomlMock.mockResolvedValue({
  220. ok: true,
  221. data: [
  222. '[models."claude-3-opus"]',
  223. 'mode = "chat"',
  224. "input_cost_per_token = 0.000015",
  225. "output_cost_per_token = 0.00006",
  226. ].join("\n"),
  227. });
  228. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  229. const result = await checkLiteLLMSyncConflicts();
  230. expect(result.ok).toBe(true);
  231. expect(result.data?.hasConflicts).toBe(true);
  232. expect(result.data?.conflicts).toHaveLength(1);
  233. expect(result.data?.conflicts[0]?.modelName).toBe("claude-3-opus");
  234. });
  235. it("should not report conflicts for manual prices not in LiteLLM", async () => {
  236. const manualPrice = makeMockPrice("custom-model", {
  237. mode: "chat",
  238. input_cost_per_token: 0.00001,
  239. });
  240. findAllManualPricesMock.mockResolvedValue(new Map([["custom-model", manualPrice]]));
  241. fetchCloudPriceTableTomlMock.mockResolvedValue({
  242. ok: true,
  243. data: ['[models."claude-3-opus"]', 'mode = "chat"', "input_cost_per_token = 0.000015"].join(
  244. "\n"
  245. ),
  246. });
  247. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  248. const result = await checkLiteLLMSyncConflicts();
  249. expect(result.ok).toBe(true);
  250. expect(result.data?.hasConflicts).toBe(false);
  251. expect(result.data?.conflicts).toHaveLength(0);
  252. });
  253. it("should reject non-admin users", async () => {
  254. getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
  255. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  256. const result = await checkLiteLLMSyncConflicts();
  257. expect(result.ok).toBe(false);
  258. expect(result.error).toContain("无权限");
  259. });
  260. it("should handle network errors gracefully", async () => {
  261. findAllManualPricesMock.mockResolvedValue(new Map());
  262. fetchCloudPriceTableTomlMock.mockResolvedValue({
  263. ok: false,
  264. error: "云端价格表拉取失败:mock",
  265. });
  266. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  267. const result = await checkLiteLLMSyncConflicts();
  268. expect(result.ok).toBe(false);
  269. expect(result.error).toContain("云端");
  270. });
  271. it("should handle invalid TOML gracefully", async () => {
  272. findAllManualPricesMock.mockResolvedValue(new Map());
  273. fetchCloudPriceTableTomlMock.mockResolvedValue({
  274. ok: true,
  275. data: "[models\ninvalid = true",
  276. });
  277. const { checkLiteLLMSyncConflicts } = await import("@/actions/model-prices");
  278. const result = await checkLiteLLMSyncConflicts();
  279. expect(result.ok).toBe(false);
  280. expect(result.error).toContain("TOML");
  281. });
  282. });
  283. describe("processPriceTableInternal - source handling", () => {
  284. it("should skip manual prices during sync by default", async () => {
  285. const manualPrice = makeMockPrice("custom-model", {
  286. mode: "chat",
  287. input_cost_per_token: 0.00001,
  288. });
  289. findAllManualPricesMock.mockResolvedValue(new Map([["custom-model", manualPrice]]));
  290. findAllLatestPricesMock.mockResolvedValue([manualPrice]);
  291. const { processPriceTableInternal } = await import("@/actions/model-prices");
  292. const result = await processPriceTableInternal(
  293. JSON.stringify({
  294. "custom-model": {
  295. mode: "chat",
  296. input_cost_per_token: 0.000015,
  297. },
  298. })
  299. );
  300. expect(result.ok).toBe(true);
  301. expect(result.data?.skippedConflicts).toContain("custom-model");
  302. expect(result.data?.unchanged).toContain("custom-model");
  303. expect(createModelPriceMock).not.toHaveBeenCalled();
  304. });
  305. it("should overwrite manual prices when specified", async () => {
  306. const manualPrice = makeMockPrice("custom-model", {
  307. mode: "chat",
  308. input_cost_per_token: 0.00001,
  309. });
  310. findAllManualPricesMock.mockResolvedValue(new Map([["custom-model", manualPrice]]));
  311. findAllLatestPricesMock.mockResolvedValue([manualPrice]);
  312. deleteModelPriceByNameMock.mockResolvedValue(undefined);
  313. createModelPriceMock.mockResolvedValue(
  314. makeMockPrice(
  315. "custom-model",
  316. {
  317. mode: "chat",
  318. input_cost_per_token: 0.000015,
  319. },
  320. "litellm"
  321. )
  322. );
  323. const { processPriceTableInternal } = await import("@/actions/model-prices");
  324. const result = await processPriceTableInternal(
  325. JSON.stringify({
  326. "custom-model": {
  327. mode: "chat",
  328. input_cost_per_token: 0.000015,
  329. },
  330. }),
  331. ["custom-model"] // Overwrite list
  332. );
  333. expect(result.ok).toBe(true);
  334. expect(result.data?.updated).toContain("custom-model");
  335. expect(deleteModelPriceByNameMock).toHaveBeenCalledWith("custom-model");
  336. expect(createModelPriceMock).toHaveBeenCalled();
  337. });
  338. it("should add new models with litellm source", async () => {
  339. findAllManualPricesMock.mockResolvedValue(new Map());
  340. findAllLatestPricesMock.mockResolvedValue([]);
  341. createModelPriceMock.mockResolvedValue(
  342. makeMockPrice(
  343. "new-model",
  344. {
  345. mode: "chat",
  346. },
  347. "litellm"
  348. )
  349. );
  350. const { processPriceTableInternal } = await import("@/actions/model-prices");
  351. const result = await processPriceTableInternal(
  352. JSON.stringify({
  353. "new-model": {
  354. mode: "chat",
  355. input_cost_per_token: 0.000001,
  356. },
  357. })
  358. );
  359. expect(result.ok).toBe(true);
  360. expect(result.data?.added).toContain("new-model");
  361. expect(createModelPriceMock).toHaveBeenCalledWith("new-model", expect.any(Object), "litellm");
  362. });
  363. it("should skip metadata fields like sample_spec", async () => {
  364. findAllManualPricesMock.mockResolvedValue(new Map());
  365. findAllLatestPricesMock.mockResolvedValue([]);
  366. const { processPriceTableInternal } = await import("@/actions/model-prices");
  367. const result = await processPriceTableInternal(
  368. JSON.stringify({
  369. sample_spec: { description: "This is metadata" },
  370. "real-model": { mode: "chat", input_cost_per_token: 0.000001 },
  371. })
  372. );
  373. expect(result.ok).toBe(true);
  374. expect(result.data?.total).toBe(1); // Only real-model
  375. expect(result.data?.failed).not.toContain("sample_spec");
  376. });
  377. it("should skip entries without mode field", async () => {
  378. findAllManualPricesMock.mockResolvedValue(new Map());
  379. findAllLatestPricesMock.mockResolvedValue([]);
  380. const { processPriceTableInternal } = await import("@/actions/model-prices");
  381. const result = await processPriceTableInternal(
  382. JSON.stringify({
  383. "invalid-model": { input_cost_per_token: 0.000001 }, // No mode
  384. "valid-model": { mode: "chat", input_cost_per_token: 0.000001 },
  385. })
  386. );
  387. expect(result.ok).toBe(true);
  388. expect(result.data?.failed).toContain("invalid-model");
  389. });
  390. it("should ignore dangerous keys when comparing price data", async () => {
  391. const existing = makeMockPrice(
  392. "safe-model",
  393. {
  394. mode: "chat",
  395. input_cost_per_token: 0.000001,
  396. output_cost_per_token: 0.000002,
  397. },
  398. "litellm"
  399. );
  400. findAllManualPricesMock.mockResolvedValue(new Map());
  401. findAllLatestPricesMock.mockResolvedValue([existing]);
  402. const { processPriceTableInternal } = await import("@/actions/model-prices");
  403. const result = await processPriceTableInternal(
  404. JSON.stringify({
  405. "safe-model": {
  406. mode: "chat",
  407. input_cost_per_token: 0.000001,
  408. output_cost_per_token: 0.000002,
  409. constructor: { prototype: { polluted: true } },
  410. },
  411. })
  412. );
  413. expect(result.ok).toBe(true);
  414. expect(result.data?.unchanged).toContain("safe-model");
  415. expect(createModelPriceMock).not.toHaveBeenCalled();
  416. });
  417. });
  418. describe("pinModelPricingProviderAsManual", () => {
  419. it("should pin a cloud provider pricing node as a local manual model price", async () => {
  420. findLatestPriceByModelAndSourceMock.mockResolvedValue(
  421. makeMockPrice(
  422. "gpt-5.4",
  423. {
  424. mode: "responses",
  425. display_name: "GPT-5.4",
  426. model_family: "gpt",
  427. pricing: {
  428. openrouter: {
  429. input_cost_per_token: 0.0000025,
  430. output_cost_per_token: 0.000015,
  431. cache_read_input_token_cost: 2.5e-7,
  432. },
  433. },
  434. },
  435. "litellm"
  436. )
  437. );
  438. upsertModelPriceMock.mockResolvedValue(
  439. makeMockPrice(
  440. "gpt-5.4",
  441. {
  442. mode: "responses",
  443. input_cost_per_token: 0.0000025,
  444. output_cost_per_token: 0.000015,
  445. cache_read_input_token_cost: 2.5e-7,
  446. selected_pricing_provider: "openrouter",
  447. },
  448. "manual"
  449. )
  450. );
  451. const { pinModelPricingProviderAsManual } = await import("@/actions/model-prices");
  452. const result = await pinModelPricingProviderAsManual({
  453. modelName: "gpt-5.4",
  454. pricingProviderKey: "openrouter",
  455. });
  456. expect(result.ok).toBe(true);
  457. expect(findLatestPriceByModelAndSourceMock).toHaveBeenCalledWith("gpt-5.4", "litellm");
  458. expect(upsertModelPriceMock).toHaveBeenCalledWith(
  459. "gpt-5.4",
  460. expect.objectContaining({
  461. mode: "responses",
  462. input_cost_per_token: 0.0000025,
  463. output_cost_per_token: 0.000015,
  464. cache_read_input_token_cost: 2.5e-7,
  465. selected_pricing_provider: "openrouter",
  466. selected_pricing_source_model: "gpt-5.4",
  467. selected_pricing_resolution: "manual_pin",
  468. })
  469. );
  470. });
  471. });
  472. });