proxy-forwarder-endpoint-audit.test.ts 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
  /**
   * Shared mock functions for the module mocks below.
   *
   * Created through `vi.hoisted` so the `vi.mock` factories (which vitest hoists
   * above the imports) can reference them.
   */
  const mocks = vi.hoisted(() => {
    // "@/lib/provider-endpoints/endpoint-selector": each test decides what this resolves/rejects with.
    const getPreferredProviderEndpoints = vi.fn();

    // "@/lib/endpoint-circuit-breaker": async no-op recorders.
    const recordEndpointSuccess = vi.fn(async () => {});
    const recordEndpointFailure = vi.fn(async () => {});

    // "@/lib/circuit-breaker": circuit reported as "closed" with zero recorded failures.
    const recordSuccess = vi.fn();
    const recordFailure = vi.fn(async () => {});
    const getCircuitState = vi.fn(() => "closed");
    const getProviderHealthInfo = vi.fn(async () => ({
      health: { failureCount: 0 },
      config: { failureThreshold: 3 },
    }));

    // "@/lib/vendor-type-circuit-breaker": never reports an open circuit.
    const isVendorTypeCircuitOpen = vi.fn(async () => false);
    const recordVendorTypeAllEndpointsTimeout = vi.fn(async () => {});

    // "@/app/v1/_lib/proxy/errors": tests queue the category to return where needed.
    const categorizeErrorAsync = vi.fn();

    return {
      getPreferredProviderEndpoints,
      recordEndpointSuccess,
      recordEndpointFailure,
      recordSuccess,
      recordFailure,
      getCircuitState,
      getProviderHealthInfo,
      isVendorTypeCircuitOpen,
      recordVendorTypeAllEndpointsTimeout,
      categorizeErrorAsync,
    };
  });
  // Module mocks. vitest hoists every `vi.mock` call above the imports, so these
  // factories reference only `vi` and the hoisted `mocks` object.

  // Logger: every level is an inert vi.fn().
  vi.mock("@/lib/logger", () => {
    const logger = {
      debug: vi.fn(),
      info: vi.fn(),
      warn: vi.fn(),
      trace: vi.fn(),
      error: vi.fn(),
      fatal: vi.fn(),
    };
    return { logger };
  });

  // Endpoint selection is driven per test through mocks.getPreferredProviderEndpoints.
  vi.mock("@/lib/provider-endpoints/endpoint-selector", () => {
    return {
      getPreferredProviderEndpoints: mocks.getPreferredProviderEndpoints,
    };
  });

  // Endpoint-level circuit breaker recorders.
  vi.mock("@/lib/endpoint-circuit-breaker", () => {
    return {
      recordEndpointSuccess: mocks.recordEndpointSuccess,
      recordEndpointFailure: mocks.recordEndpointFailure,
    };
  });

  // Provider-level circuit breaker.
  vi.mock("@/lib/circuit-breaker", () => {
    return {
      getCircuitState: mocks.getCircuitState,
      getProviderHealthInfo: mocks.getProviderHealthInfo,
      recordSuccess: mocks.recordSuccess,
      recordFailure: mocks.recordFailure,
    };
  });

  // Vendor/provider-type circuit breaker.
  vi.mock("@/lib/vendor-type-circuit-breaker", () => {
    return {
      isVendorTypeCircuitOpen: mocks.isVendorTypeCircuitOpen,
      recordVendorTypeAllEndpointsTimeout: mocks.recordVendorTypeAllEndpointsTimeout,
    };
  });

  // Keep the real errors module (e.g. ProxyError) and override only categorizeErrorAsync.
  vi.mock("@/app/v1/_lib/proxy/errors", async (importOriginal) => {
    const original = await importOriginal<typeof import("@/app/v1/_lib/proxy/errors")>();
    return { ...original, categorizeErrorAsync: mocks.categorizeErrorAsync };
  });
  53. import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder";
  54. import { ProxyError } from "@/app/v1/_lib/proxy/errors";
  55. import { ProxySession } from "@/app/v1/_lib/proxy/session";
  56. import type { Provider, ProviderEndpoint, ProviderType } from "@/types/provider";
  /**
   * Builds an enabled `ProviderEndpoint` fixture that has never been probed.
   *
   * `createdAt` and `updatedAt` are the same fixed Date instance
   * (2026-01-01T00:00:00.000Z), so fixtures are deterministic.
   *
   * @param input - identity, vendor, provider type and URL of the endpoint.
   * @returns a fully populated endpoint record.
   */
  function makeEndpoint(input: {
    id: number;
    vendorId: number;
    providerType: ProviderType;
    url: string;
  }): ProviderEndpoint {
    const { id, vendorId, providerType, url } = input;
    const fixedTimestamp = new Date("2026-01-01T00:00:00.000Z");

    return {
      id,
      vendorId,
      providerType,
      url,
      label: null,
      sortOrder: 0,
      isEnabled: true,
      // Probe bookkeeping: no probe has been recorded.
      lastProbedAt: null,
      lastProbeOk: null,
      lastProbeStatusCode: null,
      lastProbeLatencyMs: null,
      lastProbeErrorType: null,
      lastProbeErrorMessage: null,
      createdAt: fixedTimestamp,
      updatedAt: fixedTimestamp,
      deletedAt: null,
    };
  }
  /**
   * Builds a `Provider` fixture with fixed defaults, then applies `overrides`.
   *
   * Overrides are spread last, so any key they contain (even with an explicit
   * `undefined`) replaces the default. `createdAt`/`updatedAt` are fresh,
   * separate `Date` instances on every call.
   *
   * @param overrides - fields to replace on the default provider.
   * @returns the merged provider record.
   */
  function createProvider(overrides: Partial<Provider> = {}): Provider {
    const defaults: Provider = {
      // Identity and routing
      id: 1,
      name: "p1",
      url: "https://provider.example.com",
      key: "k",
      providerVendorId: 123,
      isEnabled: true,
      weight: 1,
      priority: 0,
      costMultiplier: 1,
      groupTag: null,
      providerType: "claude",
      preserveClientIp: false,
      modelRedirects: null,
      allowedModels: null,
      mcpPassthroughType: "none",
      mcpPassthroughUrl: null,
      // Spending / concurrency limits
      limit5hUsd: null,
      limitDailyUsd: null,
      dailyResetMode: "fixed",
      dailyResetTime: "00:00",
      limitWeeklyUsd: null,
      limitMonthlyUsd: null,
      limitTotalUsd: null,
      totalCostResetAt: null,
      limitConcurrentSessions: 0,
      maxRetryAttempts: null,
      // Circuit breaker
      circuitBreakerFailureThreshold: 5,
      circuitBreakerOpenDuration: 1_800_000,
      circuitBreakerHalfOpenSuccessThreshold: 2,
      // Networking and timeouts
      proxyUrl: null,
      proxyFallbackToDirect: false,
      firstByteTimeoutStreamingMs: 30_000,
      streamingIdleTimeoutMs: 10_000,
      requestTimeoutNonStreamingMs: 600_000,
      // Presentation and preferences
      websiteUrl: null,
      faviconUrl: null,
      cacheTtlPreference: null,
      context1mPreference: null,
      codexReasoningEffortPreference: null,
      codexReasoningSummaryPreference: null,
      codexTextVerbosityPreference: null,
      codexParallelToolCallsPreference: null,
      // Rate counters
      tpm: 0,
      rpm: 0,
      rpd: 0,
      cc: 0,
      // Timestamps
      createdAt: new Date(),
      updatedAt: new Date(),
      deletedAt: null,
    };

    return { ...defaults, ...overrides };
  }
  /**
   * Creates a `ProxySession` test double without running its constructor.
   *
   * The object is created from `ProxySession.prototype`, so prototype methods
   * such as `setProvider`/`getProviderChain` are available. A hand-picked set of
   * own fields is then assigned: a POST to `requestUrl` with empty headers and a
   * two-message "model-x" conversation, plus an own `isHeaderModified` that always
   * returns false.
   *
   * @param requestUrl - inbound request URL (defaults to the Claude messages path).
   * @returns the prepared session.
   */
  function createSession(requestUrl: URL = new URL("https://example.com/v1/messages")): ProxySession {
    const headers = new Headers();
    const modelName = "model-x";

    const state = {
      // Inbound request
      startTime: Date.now(),
      method: "POST",
      requestUrl,
      headers,
      originalHeaders: new Headers(headers),
      headerLog: JSON.stringify(Object.fromEntries(headers.entries())),
      request: {
        model: modelName,
        log: "(test)",
        message: {
          model: modelName,
          messages: [
            { role: "user", content: "hello" },
            { role: "assistant", content: "ok" },
          ],
        },
      },
      userAgent: null,
      context: null,
      clientAbortSignal: null,
      // Caller identity
      userName: "test-user",
      authState: { success: true, user: null, key: null, apiKey: null },
      // Routing state, populated later by the code under test
      provider: null,
      messageContext: null,
      sessionId: null,
      requestSequence: 1,
      originalFormat: "claude",
      providerType: null,
      originalModelName: null,
      originalUrlPathname: null,
      providerChain: [],
      cacheTtlResolved: null,
      context1mApplied: false,
      specialSettings: [],
      cachedPriceData: undefined,
      cachedBillingModelSource: undefined,
      isHeaderModified: () => false,
    };

    return Object.assign(Object.create(ProxySession.prototype), state) as ProxySession;
  }
  describe("ProxyForwarder - endpoint audit", () => {
    /**
     * View of `ProxyForwarder` exposing `doForward` for spying. The double cast
     * below is needed because `doForward` is not on the public type
     * (presumably private).
     */
    type DoForwardHost = { doForward: (...args: unknown[]) => unknown };

    /** Spies on `ProxyForwarder.doForward`; the spy calls through until stubbed. */
    function spyOnDoForward() {
      return vi.spyOn(ProxyForwarder as unknown as DoForwardHost, "doForward");
    }

    /** A fresh minimal successful upstream response: HTTP 200 with body `{}`. */
    function createJsonOkResponse(): Response {
      return new Response("{}", {
        status: 200,
        headers: {
          "content-type": "application/json",
          "content-length": "2",
        },
      });
    }

    let doForward: ReturnType<typeof spyOnDoForward>;

    beforeEach(() => {
      vi.clearAllMocks();
      // clearAllMocks only wipes call history. It keeps queued `*Once` values and
      // persistent resolved/rejected values, so a test that fails before consuming
      // them would leak them into the next test. These two have no default
      // implementation, so mockReset is safe.
      mocks.getPreferredProviderEndpoints.mockReset();
      mocks.categorizeErrorAsync.mockReset();
      doForward = spyOnDoForward();
    });

    afterEach(() => {
      // Always undo per-test patches, even when an assertion threw mid-test.
      doForward.mockRestore();
      vi.useRealTimers();
    });

    // Success: the chain entry records endpointId and a redacted endpointUrl.
    test("成功时应记录 endpointId 且对 endpointUrl 做脱敏", async () => {
      const session = createSession();
      const provider = createProvider({ providerType: "claude", providerVendorId: 123 });
      session.setProvider(provider);

      mocks.getPreferredProviderEndpoints.mockResolvedValue([
        makeEndpoint({
          id: 42,
          vendorId: 123,
          providerType: provider.providerType,
          url: "https://api.example.com/v1/messages?api_key=SECRET&foo=bar",
        }),
      ]);
      doForward.mockResolvedValueOnce(createJsonOkResponse());

      const response = await ProxyForwarder.send(session);
      expect(response.status).toBe(200);

      const chain = session.getProviderChain();
      expect(chain).toHaveLength(1);
      const item = chain[0];
      expect(item).toEqual(
        expect.objectContaining({
          reason: "request_success",
          attemptNumber: 1,
          statusCode: 200,
          vendorId: 123,
          providerType: "claude",
          endpointId: 42,
        })
      );
      expect(item.endpointUrl).toContain("[REDACTED]");
      expect(item.endpointUrl).not.toContain("SECRET");
    });

    // Retry: each attempt records its own endpoint audit fields.
    test("重试时应分别记录每次 attempt 的 endpoint 审计字段", async () => {
      vi.useFakeTimers();

      const session = createSession(new URL("https://example.com/v1/chat/completions"));
      const provider = createProvider({
        providerType: "openai-compatible",
        providerVendorId: 123,
      });
      session.setProvider(provider);

      mocks.getPreferredProviderEndpoints.mockResolvedValue([
        makeEndpoint({
          id: 1,
          vendorId: 123,
          providerType: provider.providerType,
          url: "https://api.example.com/v1?token=SECRET_1",
        }),
        makeEndpoint({
          id: 2,
          vendorId: 123,
          providerType: provider.providerType,
          url: "https://api.example.com/v1?api_key=SECRET_2",
        }),
      ]);

      // Throw a network error (SYSTEM_ERROR) to trigger endpoint switching.
      // PROVIDER_ERROR (HTTP 4xx/5xx) doesn't trigger an endpoint switch; only SYSTEM_ERROR does.
      doForward.mockImplementationOnce(async () => {
        const err = new Error("ECONNREFUSED") as NodeJS.ErrnoException;
        err.code = "ECONNREFUSED";
        throw err;
      });
      // Categorize the network error as SYSTEM_ERROR (ErrorCategory.SYSTEM_ERROR = 1 per the original note).
      mocks.categorizeErrorAsync.mockResolvedValueOnce(1);
      doForward.mockResolvedValueOnce(createJsonOkResponse());

      const sendPromise = ProxyForwarder.send(session);
      // Let any retry back-off delay elapse under fake timers.
      await vi.advanceTimersByTimeAsync(100);
      const response = await sendPromise;
      expect(response.status).toBe(200);

      const chain = session.getProviderChain();
      expect(chain).toHaveLength(2);
      const first = chain[0];
      const second = chain[1];

      expect(first).toEqual(
        expect.objectContaining({
          reason: "system_error",
          attemptNumber: 1,
          vendorId: 123,
          providerType: "openai-compatible",
          endpointId: 1,
        })
      );
      expect(first.endpointUrl).toContain("[REDACTED]");
      expect(first.endpointUrl).not.toContain("SECRET_1");

      expect(second).toEqual(
        expect.objectContaining({
          reason: "retry_success",
          attemptNumber: 2,
          vendorId: 123,
          providerType: "openai-compatible",
          endpointId: 2,
        })
      );
      expect(second.endpointUrl).toContain("[REDACTED]");
      expect(second.endpointUrl).not.toContain("SECRET_2");
    });

    // Endpoint selection failure: fall back to provider.url and record endpointId = null.
    test("endpoint 选择失败时应回退到 provider.url,并记录 endpointId=null", async () => {
      const session = createSession();
      const provider = createProvider({
        providerType: "claude",
        providerVendorId: 123,
        url: "https://provider.example.com/v1/messages?key=SECRET",
      });
      session.setProvider(provider);

      mocks.getPreferredProviderEndpoints.mockRejectedValue(new Error("boom"));
      doForward.mockResolvedValueOnce(createJsonOkResponse());

      const response = await ProxyForwarder.send(session);
      expect(response.status).toBe(200);

      const chain = session.getProviderChain();
      expect(chain).toHaveLength(1);
      const item = chain[0];
      expect(item).toEqual(
        expect.objectContaining({
          endpointId: null,
        })
      );
      expect(item.endpointUrl).toContain("[REDACTED]");
      expect(item.endpointUrl).not.toContain("SECRET");
    });
  });