// billing-model-source.test.ts
//
// Verifies that the `billingModelSource` system setting ("original" vs "redirected") selects the
// same model price for the DB request cost, the Redis session cost and the rate-limit cost, and
// that a missing or failing price lookup skips billing without breaking the client response.
  1. import { beforeEach, describe, expect, it, vi } from "vitest";
  2. import type { ModelPrice, ModelPriceData } from "@/types/model-price";
  3. import type { SystemSettings } from "@/types/system-config";
  4. const asyncTasks: Promise<void>[] = [];
  5. const cloudPriceSyncRequests: Array<{ reason: string }> = [];
  6. vi.mock("@/lib/async-task-manager", () => ({
  7. AsyncTaskManager: {
  8. register: (_taskId: string, promise: Promise<void>) => {
  9. asyncTasks.push(promise);
  10. return new AbortController();
  11. },
  12. cleanup: () => {},
  13. cancel: () => {},
  14. },
  15. }));
  16. vi.mock("@/lib/logger", () => ({
  17. logger: {
  18. debug: () => {},
  19. info: () => {},
  20. warn: () => {},
  21. error: () => {},
  22. trace: () => {},
  23. },
  24. }));
  25. vi.mock("@/lib/price-sync/cloud-price-updater", () => ({
  26. requestCloudPriceTableSync: (payload: { reason: string }) => {
  27. cloudPriceSyncRequests.push(payload);
  28. },
  29. }));
  30. vi.mock("@/repository/model-price", () => ({
  31. findLatestPriceByModel: vi.fn(),
  32. }));
  33. vi.mock("@/repository/system-config", () => ({
  34. getSystemSettings: vi.fn(),
  35. }));
  36. vi.mock("@/repository/message", () => ({
  37. updateMessageRequestCost: vi.fn(),
  38. updateMessageRequestDetails: vi.fn(),
  39. updateMessageRequestDuration: vi.fn(),
  40. }));
  41. vi.mock("@/lib/session-manager", () => ({
  42. SessionManager: {
  43. updateSessionUsage: vi.fn(),
  44. storeSessionResponse: vi.fn(),
  45. extractCodexPromptCacheKey: vi.fn(),
  46. updateSessionWithCodexCacheKey: vi.fn(),
  47. },
  48. }));
  49. vi.mock("@/lib/rate-limit", () => ({
  50. RateLimitService: {
  51. trackCost: vi.fn(),
  52. trackUserDailyCost: vi.fn(),
  53. },
  54. }));
  55. vi.mock("@/lib/session-tracker", () => ({
  56. SessionTracker: {
  57. refreshSession: vi.fn(),
  58. },
  59. }));
  60. vi.mock("@/lib/proxy-status-tracker", () => ({
  61. ProxyStatusTracker: {
  62. getInstance: () => ({
  63. endRequest: () => {},
  64. }),
  65. },
  66. }));
  67. import { ProxyResponseHandler } from "@/app/v1/_lib/proxy/response-handler";
  68. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  69. import { SessionManager } from "@/lib/session-manager";
  70. import { RateLimitService } from "@/lib/rate-limit";
  71. import { SessionTracker } from "@/lib/session-tracker";
  72. import {
  73. updateMessageRequestCost,
  74. updateMessageRequestDetails,
  75. updateMessageRequestDuration,
  76. } from "@/repository/message";
  77. import { findLatestPriceByModel } from "@/repository/model-price";
  78. import { getSystemSettings } from "@/repository/system-config";
  79. beforeEach(() => {
  80. cloudPriceSyncRequests.splice(0, cloudPriceSyncRequests.length);
  81. });
  82. function makeSystemSettings(
  83. billingModelSource: SystemSettings["billingModelSource"]
  84. ): SystemSettings {
  85. const now = new Date();
  86. return {
  87. id: 1,
  88. siteTitle: "test",
  89. allowGlobalUsageView: false,
  90. currencyDisplay: "USD",
  91. billingModelSource,
  92. timezone: null,
  93. enableAutoCleanup: false,
  94. cleanupRetentionDays: 30,
  95. cleanupSchedule: "0 2 * * *",
  96. cleanupBatchSize: 10000,
  97. enableClientVersionCheck: false,
  98. verboseProviderError: false,
  99. enableHttp2: false,
  100. interceptAnthropicWarmupRequests: false,
  101. enableResponseFixer: true,
  102. responseFixerConfig: {
  103. fixTruncatedJson: true,
  104. fixSseFormat: true,
  105. fixEncoding: true,
  106. maxJsonDepth: 200,
  107. maxFixSize: 1024 * 1024,
  108. },
  109. createdAt: now,
  110. updatedAt: now,
  111. };
  112. }
  113. function makePriceRecord(modelName: string, priceData: ModelPriceData): ModelPrice {
  114. const now = new Date();
  115. return {
  116. id: 1,
  117. modelName,
  118. priceData,
  119. createdAt: now,
  120. updatedAt: now,
  121. };
  122. }
  123. function createSession({
  124. originalModel,
  125. redirectedModel,
  126. sessionId,
  127. messageId,
  128. }: {
  129. originalModel: string;
  130. redirectedModel: string;
  131. sessionId: string;
  132. messageId: number;
  133. }): ProxySession {
  134. const session = new (
  135. ProxySession as unknown as {
  136. new (init: {
  137. startTime: number;
  138. method: string;
  139. requestUrl: URL;
  140. headers: Headers;
  141. headerLog: string;
  142. request: { message: Record<string, unknown>; log: string; model: string | null };
  143. userAgent: string | null;
  144. context: unknown;
  145. clientAbortSignal: AbortSignal | null;
  146. }): ProxySession;
  147. }
  148. )({
  149. startTime: Date.now(),
  150. method: "POST",
  151. requestUrl: new URL("http://localhost/v1/messages"),
  152. headers: new Headers(),
  153. headerLog: "",
  154. request: { message: {}, log: "(test)", model: redirectedModel },
  155. userAgent: null,
  156. context: {},
  157. clientAbortSignal: null,
  158. });
  159. session.setOriginalModel(originalModel);
  160. session.setSessionId(sessionId);
  161. const provider = {
  162. id: 99,
  163. name: "test-provider",
  164. providerType: "claude",
  165. costMultiplier: 1.0,
  166. streamingIdleTimeoutMs: 0,
  167. } as any;
  168. const user = {
  169. id: 123,
  170. name: "test-user",
  171. dailyResetTime: "00:00",
  172. dailyResetMode: "fixed",
  173. } as any;
  174. const key = {
  175. id: 456,
  176. name: "test-key",
  177. dailyResetTime: "00:00",
  178. dailyResetMode: "fixed",
  179. } as any;
  180. session.setProvider(provider);
  181. session.setAuthState({
  182. user,
  183. key,
  184. apiKey: "sk-test",
  185. success: true,
  186. });
  187. session.setMessageContext({
  188. id: messageId,
  189. createdAt: new Date(),
  190. user,
  191. key,
  192. apiKey: "sk-test",
  193. });
  194. return session;
  195. }
  196. function createNonStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
  197. return new Response(
  198. JSON.stringify({
  199. type: "message",
  200. usage,
  201. }),
  202. {
  203. status: 200,
  204. headers: { "content-type": "application/json" },
  205. }
  206. );
  207. }
  208. function createStreamResponse(usage: { input_tokens: number; output_tokens: number }): Response {
  209. const sseText = `event: message_delta\ndata: ${JSON.stringify({ usage })}\n\n`;
  210. const encoder = new TextEncoder();
  211. const stream = new ReadableStream<Uint8Array>({
  212. start(controller) {
  213. controller.enqueue(encoder.encode(sseText));
  214. controller.close();
  215. },
  216. });
  217. return new Response(stream, {
  218. status: 200,
  219. headers: { "content-type": "text/event-stream" },
  220. });
  221. }
  222. async function drainAsyncTasks(): Promise<void> {
  223. const tasks = asyncTasks.splice(0, asyncTasks.length);
  224. await Promise.all(tasks);
  225. }
  226. async function runScenario({
  227. billingModelSource,
  228. isStream,
  229. }: {
  230. billingModelSource: SystemSettings["billingModelSource"];
  231. isStream: boolean;
  232. }): Promise<{ dbCostUsd: string; sessionCostUsd: string; rateLimitCost: number }> {
  233. const usage = { input_tokens: 2, output_tokens: 3 };
  234. const originalModel = "original-model";
  235. const redirectedModel = "redirected-model";
  236. const originalPriceData: ModelPriceData = { input_cost_per_token: 1, output_cost_per_token: 1 };
  237. const redirectedPriceData: ModelPriceData = {
  238. input_cost_per_token: 10,
  239. output_cost_per_token: 10,
  240. };
  241. vi.mocked(getSystemSettings).mockResolvedValue(makeSystemSettings(billingModelSource));
  242. vi.mocked(findLatestPriceByModel).mockImplementation(async (modelName: string) => {
  243. if (modelName === originalModel) {
  244. return makePriceRecord(modelName, originalPriceData);
  245. }
  246. if (modelName === redirectedModel) {
  247. return makePriceRecord(modelName, redirectedPriceData);
  248. }
  249. return null;
  250. });
  251. vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
  252. vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
  253. vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
  254. vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
  255. vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
  256. const dbCosts: string[] = [];
  257. vi.mocked(updateMessageRequestCost).mockImplementation(async (_id: number, costUsd: unknown) => {
  258. dbCosts.push(String(costUsd));
  259. });
  260. const sessionCosts: string[] = [];
  261. vi.mocked(SessionManager.updateSessionUsage).mockImplementation(
  262. async (_sessionId: string, payload: Record<string, unknown>) => {
  263. if (typeof payload.costUsd === "string") {
  264. sessionCosts.push(payload.costUsd);
  265. }
  266. }
  267. );
  268. const rateLimitCosts: number[] = [];
  269. vi.mocked(RateLimitService.trackCost).mockImplementation(
  270. async (_keyId: number, _providerId: number, _sessionId: string, costUsd: number) => {
  271. rateLimitCosts.push(costUsd);
  272. }
  273. );
  274. const session = createSession({
  275. originalModel,
  276. redirectedModel,
  277. sessionId: `sess-${billingModelSource}-${isStream ? "s" : "n"}`,
  278. messageId: isStream ? 2001 : 2000,
  279. });
  280. const response = isStream ? createStreamResponse(usage) : createNonStreamResponse(usage);
  281. const clientResponse = await ProxyResponseHandler.dispatch(session, response);
  282. if (isStream) {
  283. await clientResponse.text();
  284. }
  285. await drainAsyncTasks();
  286. const dbCostUsd = dbCosts[0] ?? "";
  287. const sessionCostUsd = sessionCosts[0] ?? "";
  288. const rateLimitCost = rateLimitCosts[0] ?? Number.NaN;
  289. return { dbCostUsd, sessionCostUsd, rateLimitCost };
  290. }
  291. describe("Billing model source - Redis session cost vs DB cost", () => {
  292. it("非流式响应:配置 = original 时 Session 成本与数据库一致", async () => {
  293. const result = await runScenario({ billingModelSource: "original", isStream: false });
  294. expect(result.dbCostUsd).toBe("5");
  295. expect(result.sessionCostUsd).toBe("5");
  296. expect(result.rateLimitCost).toBe(5);
  297. });
  298. it("非流式响应:配置 = redirected 时 Session 成本与数据库一致", async () => {
  299. const result = await runScenario({ billingModelSource: "redirected", isStream: false });
  300. expect(result.dbCostUsd).toBe("50");
  301. expect(result.sessionCostUsd).toBe("50");
  302. expect(result.rateLimitCost).toBe(50);
  303. });
  304. it("流式响应:配置 = original 时 Session 成本与数据库一致", async () => {
  305. const result = await runScenario({ billingModelSource: "original", isStream: true });
  306. expect(result.dbCostUsd).toBe("5");
  307. expect(result.sessionCostUsd).toBe("5");
  308. expect(result.rateLimitCost).toBe(5);
  309. });
  310. it("流式响应:配置 = redirected 时 Session 成本与数据库一致", async () => {
  311. const result = await runScenario({ billingModelSource: "redirected", isStream: true });
  312. expect(result.dbCostUsd).toBe("50");
  313. expect(result.sessionCostUsd).toBe("50");
  314. expect(result.rateLimitCost).toBe(50);
  315. });
  316. it("从 original 切换到 redirected 后应生效", async () => {
  317. const original = await runScenario({ billingModelSource: "original", isStream: false });
  318. const redirected = await runScenario({ billingModelSource: "redirected", isStream: false });
  319. expect(original.sessionCostUsd).toBe("5");
  320. expect(redirected.sessionCostUsd).toBe("50");
  321. expect(original.sessionCostUsd).not.toBe(redirected.sessionCostUsd);
  322. });
  323. });
  324. describe("价格表缺失/查询失败:不计费放行", () => {
  325. async function runNoPriceScenario(options: {
  326. billingModelSource: SystemSettings["billingModelSource"];
  327. isStream: boolean;
  328. priceLookup: "none" | "throws";
  329. }): Promise<{ dbCostCalls: number; rateLimitCalls: number }> {
  330. const usage = { input_tokens: 2, output_tokens: 3 };
  331. const originalModel = "original-model";
  332. const redirectedModel = "redirected-model";
  333. vi.mocked(getSystemSettings).mockResolvedValue(makeSystemSettings(options.billingModelSource));
  334. if (options.priceLookup === "none") {
  335. vi.mocked(findLatestPriceByModel).mockResolvedValue(null);
  336. } else {
  337. vi.mocked(findLatestPriceByModel).mockImplementation(async () => {
  338. throw new Error("db query failed");
  339. });
  340. }
  341. vi.mocked(updateMessageRequestDetails).mockResolvedValue(undefined);
  342. vi.mocked(updateMessageRequestDuration).mockResolvedValue(undefined);
  343. vi.mocked(SessionManager.storeSessionResponse).mockResolvedValue(undefined);
  344. vi.mocked(RateLimitService.trackUserDailyCost).mockResolvedValue(undefined);
  345. vi.mocked(SessionTracker.refreshSession).mockResolvedValue(undefined);
  346. vi.mocked(updateMessageRequestCost).mockResolvedValue(undefined);
  347. vi.mocked(RateLimitService.trackCost).mockResolvedValue(undefined);
  348. vi.mocked(SessionManager.updateSessionUsage).mockResolvedValue(undefined);
  349. const session = createSession({
  350. originalModel,
  351. redirectedModel,
  352. sessionId: `sess-no-price-${options.billingModelSource}-${options.isStream ? "s" : "n"}`,
  353. messageId: options.isStream ? 3001 : 3000,
  354. });
  355. const response = options.isStream
  356. ? createStreamResponse(usage)
  357. : createNonStreamResponse(usage);
  358. const clientResponse = await ProxyResponseHandler.dispatch(session, response);
  359. await clientResponse.text();
  360. await drainAsyncTasks();
  361. return {
  362. dbCostCalls: vi.mocked(updateMessageRequestCost).mock.calls.length,
  363. rateLimitCalls: vi.mocked(RateLimitService.trackCost).mock.calls.length,
  364. };
  365. }
  366. it("无价格:不写入 DB cost,不追踪限流 cost,并触发一次异步同步", async () => {
  367. const result = await runNoPriceScenario({
  368. billingModelSource: "redirected",
  369. isStream: false,
  370. priceLookup: "none",
  371. });
  372. expect(result.dbCostCalls).toBe(0);
  373. expect(result.rateLimitCalls).toBe(0);
  374. expect(cloudPriceSyncRequests).toEqual([{ reason: "missing-model" }]);
  375. });
  376. it("价格查询抛错:不应影响响应,不写入 DB cost,不追踪限流 cost", async () => {
  377. const result = await runNoPriceScenario({
  378. billingModelSource: "original",
  379. isStream: true,
  380. priceLookup: "throws",
  381. });
  382. expect(result.dbCostCalls).toBe(0);
  383. expect(result.rateLimitCalls).toBe(0);
  384. });
  385. });