Przeglądaj źródła

feat: support gemini and openai embeddings passthrough (#931)

Ding 4 tygodni temu
rodzic
commit
2df6696cb0

+ 3 - 0
src/app/v1/_lib/proxy/endpoint-paths.ts

@@ -6,6 +6,7 @@ export const V1_ENDPOINT_PATHS = {
   RESPONSES: "/v1/responses",
   RESPONSES_COMPACT: "/v1/responses/compact",
   CHAT_COMPLETIONS: "/v1/chat/completions",
+  EMBEDDINGS: "/v1/embeddings",
   MODELS: "/v1/models",
 } as const;
 
@@ -15,6 +16,7 @@ export const STANDARD_ENDPOINT_PATHS = [
   V1_ENDPOINT_PATHS.RESPONSES,
   V1_ENDPOINT_PATHS.RESPONSES_COMPACT,
   V1_ENDPOINT_PATHS.CHAT_COMPLETIONS,
+  V1_ENDPOINT_PATHS.EMBEDDINGS,
   V1_ENDPOINT_PATHS.MODELS,
 ] as const;
 
@@ -23,6 +25,7 @@ export const STRICT_STANDARD_ENDPOINT_PATHS = [
   V1_ENDPOINT_PATHS.RESPONSES,
   V1_ENDPOINT_PATHS.RESPONSES_COMPACT,
   V1_ENDPOINT_PATHS.CHAT_COMPLETIONS,
+  V1_ENDPOINT_PATHS.EMBEDDINGS,
 ] as const;
 
 const standardEndpointPathSet = new Set<string>(STANDARD_ENDPOINT_PATHS);

+ 7 - 6
src/app/v1/_lib/proxy/format-mapper.ts

@@ -31,8 +31,8 @@ export type ClientFormat = "response" | "openai" | "claude" | "gemini" | "gemini
  * 支持的端点模式:
  * - Claude Messages API: `/v1/messages`, `/v1/messages/count_tokens`
  * - Codex Response API: `/v1/responses`
- * - OpenAI Compatible: `/v1/chat/completions`
- * - Gemini Direct: `/v1beta/models/{model}:generateContent`
+ * - OpenAI Compatible: `/v1/chat/completions`, `/v1/embeddings`
+ * - Gemini Direct: `/v1beta/models/{model}:generateContent`, `/v1beta/models/{model}:embedContent`
  * - Gemini CLI: `/v1internal/models/{model}:generateContent`
  *
  * @param pathname - URL 路径(如 `/v1/messages`)
@@ -58,19 +58,20 @@ export function detectFormatByEndpoint(pathname: string): ClientFormat | null {
     // Codex / Response API
     { pattern: /^\/v1\/responses$/i, format: "response" },
 
-    // OpenAI Chat Completions
-    { pattern: /^\/v1\/chat\/completions$/i, format: "openai" },
+    // OpenAI Chat Completions / Embeddings
+    { pattern: /^\/v1\/(?:chat\/completions|embeddings)$/i, format: "openai" },
 
     // Gemini Vertex AI (publishers path)
     {
       pattern:
-        /^\/v1\/publishers\/google\/models\/[^/:]+:(?:generateContent|streamGenerateContent|countTokens)$/i,
+        /^\/v1\/publishers\/google\/models\/[^/:]+:(?:generateContent|streamGenerateContent|countTokens|embedContent)$/i,
       format: "gemini",
     },
 
     // Gemini Direct API
     {
-      pattern: /^\/v1beta\/models\/[^/:]+:(?:generateContent|streamGenerateContent|countTokens)$/i,
+      pattern:
+        /^\/v1beta\/models\/[^/:]+:(?:generateContent|streamGenerateContent|countTokens|embedContent)$/i,
       format: "gemini",
     },
 

+ 7 - 1
src/app/v1/_lib/proxy/forwarder.ts

@@ -79,10 +79,16 @@ const STANDARD_ENDPOINTS = [
   "/v1/responses",
   "/v1/responses/compact",
   "/v1/chat/completions",
+  "/v1/embeddings",
   "/v1/models",
 ];
 
-const STRICT_STANDARD_ENDPOINTS = ["/v1/messages", "/v1/responses", "/v1/chat/completions"];
+const STRICT_STANDARD_ENDPOINTS = [ // NOTE(review): hard-coded path list; endpoint-paths.ts in this same commit has a similarly named STRICT_STANDARD_ENDPOINT_PATHS — confirm the two are meant to stay in sync
+  "/v1/messages",
+  "/v1/responses",
+  "/v1/chat/completions",
+  "/v1/embeddings", // added by this commit
+];
 
 const OUTBOUND_TRANSPORT_HEADER_BLACKLIST = ["content-length", "connection", "transfer-encoding"];
 

+ 1 - 0
src/app/v1/_lib/url.ts

@@ -6,6 +6,7 @@ const targetEndpoints = [
   "/responses", // Codex Response API
   "/messages", // Claude Messages API
   "/chat/completions", // OpenAI Compatible
+  "/embeddings", // OpenAI Compatible Embeddings
   "/models", // Gemini & OpenAI models
 ] as const;
 

+ 9 - 0
tests/unit/app/v1/url.test.ts

@@ -28,6 +28,15 @@ describe("buildProxyUrl", () => {
     expect(out).toBe("https://example.com/openai/responses?x=1");
   });
 
+  test("避免重复拼接:baseUrl 已包含 /embeddings 时不追加 /v1/embeddings", () => { // Title in English: avoid duplicate concatenation — when baseUrl already contains /embeddings, do not append /v1/embeddings
+    const out = buildProxyUrl(
+      "https://example.com/openai/embeddings", // provider base URL that already ends with the target endpoint
+      new URL("https://dummy.com/v1/embeddings?x=1") // incoming request URL carrying the /v1/embeddings path and a query string
+    );
+
+    expect(out).toBe("https://example.com/openai/embeddings?x=1"); // base URL path kept once, request query string appended
+  });
+
   test("子路径不丢失:baseUrl=/v1/messages + request=/v1/messages/count_tokens", () => {
     const out = buildProxyUrl(
       "https://api.example.com/v1/messages",

+ 22 - 0
tests/unit/proxy/gemini-embedcontent-format.test.ts

@@ -0,0 +1,22 @@
+import { describe, expect, it } from "vitest";
+import { detectFormatByEndpoint } from "@/app/v1/_lib/proxy/format-mapper";
+
+describe("detectFormatByEndpoint - Gemini embedContent", () => { // Checks how detectFormatByEndpoint classifies Gemini-style paths ending in an action suffix
+  it.each([ // accepted shapes: the /v1beta direct path and the /v1/publishers/google path, both with :embedContent
+    "/v1beta/models/gemini-2.5-flash:embedContent",
+    "/v1/publishers/google/models/gemini-2.5-pro:embedContent",
+  ])('returns "gemini" for %s', (pathname) => {
+    expect(detectFormatByEndpoint(pathname)).toBe("gemini");
+  });
+
+  it.each([ // same two path prefixes, but with an action name that must not be recognized
+    "/v1beta/models/gemini-2.5-flash:unknownAction",
+    "/v1/publishers/google/models/gemini-2.5-pro:unknownAction",
+  ])("returns null for unknown Gemini actions: %s", (pathname) => {
+    expect(detectFormatByEndpoint(pathname)).toBeNull();
+  });
+
+  it("does not classify internal embedContent as gemini-cli", () => {
+    expect(detectFormatByEndpoint("/v1internal/models/gemini-2.5-flash:embedContent")).toBeNull(); // the /v1internal prefix combined with :embedContent is expected to match no format at all
+  });
+});

+ 25 - 0
tests/unit/proxy/openai-embeddings-format.test.ts

@@ -0,0 +1,25 @@
+import { describe, expect, it } from "vitest";
+import { detectFormatByEndpoint } from "@/app/v1/_lib/proxy/format-mapper";
+import { isRawPassthroughEndpointPath } from "@/app/v1/_lib/proxy/endpoint-policy";
+import {
+  isStandardEndpointPath,
+  isStrictStandardEndpointPath,
+} from "@/app/v1/_lib/proxy/endpoint-paths";
+
+describe("detectFormatByEndpoint - OpenAI embeddings", () => { // Covers format detection plus the endpoint-path and endpoint-policy classification of /v1/embeddings
+  it('returns "openai" for /v1/embeddings', () => {
+    expect(detectFormatByEndpoint("/v1/embeddings")).toBe("openai");
+  });
+
+  it("classifies /v1/embeddings as a standard endpoint", () => {
+    expect(isStandardEndpointPath("/v1/embeddings")).toBe(true);
+  });
+
+  it("classifies /v1/embeddings as a strict standard endpoint", () => {
+    expect(isStrictStandardEndpointPath("/v1/embeddings")).toBe(true);
+  });
+
+  it("does not classify /v1/embeddings as raw passthrough", () => {
+    expect(isRawPassthroughEndpointPath("/v1/embeddings")).toBe(false); // helper comes from endpoint-policy, not endpoint-paths
+  });
+});

+ 134 - 0
tests/unit/proxy/openai-embeddings-forwarder.test.ts

@@ -0,0 +1,134 @@
+import { beforeEach, describe, expect, it, vi } from "vitest";
+import { resolveEndpointPolicy } from "@/app/v1/_lib/proxy/endpoint-policy";
+import { ProxyForwarder } from "@/app/v1/_lib/proxy/forwarder";
+import { ProxySession } from "@/app/v1/_lib/proxy/session";
+import type { Provider } from "@/types/provider";
+
+vi.mock("@/lib/logger", () => ({ // Replaces the logger module with a stub object for this test file
+  logger: {
+    debug: vi.fn(), // every log level below is a bare vi.fn() mock
+    error: vi.fn(),
+    info: vi.fn(),
+    trace: vi.fn(),
+    warn: vi.fn(),
+    fatal: vi.fn(),
+  },
+}));
+
+vi.mock("@/lib/request-filter-engine", () => ({ // Replaces the request filter engine module with a stub for this test file
+  requestFilterEngine: {
+    applyFinal: vi.fn(async () => {}), // async no-op mock
+  },
+}));
+
+function createProvider(): Provider { // Builds the provider fixture used by the forwarding test in this file
+  return {
+    id: 1,
+    name: "openai-upstream",
+    providerType: "openai-compatible",
+    url: "https://openai.example.com/openai", // base URL; the test expects the forwarded URL to be this plus /v1/embeddings
+    key: "upstream-key",
+    preserveClientIp: false,
+    priority: 0,
+    costMultiplier: 1,
+    maxRetryAttempts: 1,
+    mcpPassthroughType: "minimax", // MCP passthrough is configured on purpose: the test asserts that /v1/embeddings is not sent to it
+    mcpPassthroughUrl: "https://mcp.example.com",
+  } as unknown as Provider; // double cast around the type check — presumably Provider has more required fields than listed here; verify
+}
+
+function createSession(): ProxySession { // Builds a hand-assembled ProxySession fixture for a POST to /v1/embeddings
+  const headers = new Headers({
+    "content-type": "application/json",
+    authorization: "Bearer proxy-user-key",
+  });
+  const session = Object.create(ProxySession.prototype); // prototype-only instance: the ProxySession constructor is never invoked
+
+  Object.assign(session, { // fields and stubbed methods are written directly onto the instance
+    startTime: Date.now(),
+    method: "POST",
+    requestUrl: new URL("https://proxy.example.com/v1/embeddings"),
+    headers,
+    originalHeaders: new Headers(headers), // separate Headers object initialized from the same entries
+    headerLog: JSON.stringify(Object.fromEntries(headers.entries())),
+    request: {
+      model: "text-embedding-3-large",
+      log: JSON.stringify({ // serialized form of the same body as `message` below
+        model: "text-embedding-3-large",
+        input: "embedding me",
+      }),
+      message: {
+        model: "text-embedding-3-large",
+        input: "embedding me",
+      },
+    },
+    userAgent: "OpenAITest/1.0",
+    context: null,
+    clientAbortSignal: null,
+    userName: "test-user",
+    authState: { success: true, user: null, key: null, apiKey: null },
+    provider: null,
+    messageContext: null,
+    sessionId: null,
+    requestSequence: 1,
+    originalFormat: "openai",
+    providerType: null,
+    originalModelName: null,
+    originalUrlPathname: null,
+    providerChain: [],
+    cacheTtlResolved: null,
+    context1mApplied: false,
+    cachedPriceData: undefined,
+    cachedBillingModelSource: undefined,
+    forwardedRequestBody: null,
+    endpointPolicy: resolveEndpointPolicy("/v1/embeddings"), // real policy lookup (resolveEndpointPolicy is imported, not mocked, in this file)
+    setCacheTtlResolved: vi.fn(), // the entries from here down override methods with vi.fn() stubs returning fixed values
+    getCacheTtlResolved: vi.fn(() => null),
+    getCurrentModel: vi.fn(() => "text-embedding-3-large"),
+    clientRequestsContext1m: vi.fn(() => false),
+    setContext1mApplied: vi.fn(),
+    getContext1mApplied: vi.fn(() => false),
+    getEndpointPolicy: vi.fn(() => resolveEndpointPolicy("/v1/embeddings")),
+    isHeaderModified: vi.fn(() => false),
+  });
+
+  return session as ProxySession; // Object.create returns `any`, so the cast restores the declared return type
+}
+
+describe("ProxyForwarder - OpenAI embeddings standard endpoint handling", () => { // Verifies which upstream URL is used when /v1/embeddings is forwarded
+  beforeEach(() => {
+    vi.clearAllMocks(); // clear recorded mock calls before each test
+  });
+
+  it("does not route /v1/embeddings through MCP passthrough URL", async () => {
+    const provider = createProvider();
+    const session = createSession();
+    let capturedUrl: string | null = null;
+
+    const fetchWithoutAutoDecode = vi.spyOn(ProxyForwarder as never, "fetchWithoutAutoDecode"); // `as never` bypasses type checking of the member name — presumably because the member is not public; verify
+    fetchWithoutAutoDecode.mockImplementationOnce(async (url: string) => { // stubs the next call only: records the URL and returns a canned embeddings response
+      capturedUrl = url;
+      return new Response(
+        JSON.stringify({
+          object: "list",
+          data: [{ object: "embedding", embedding: [0.1, 0.2], index: 0 }],
+          model: "text-embedding-3-large",
+          usage: { prompt_tokens: 3, total_tokens: 3 },
+        }),
+        {
+          status: 200,
+          headers: { "content-type": "application/json" },
+        }
+      );
+    });
+
+    const { doForward } = ProxyForwarder as unknown as { // cast exposes doForward with an explicit signature; NOTE(review): destructuring detaches it from ProxyForwarder — assumes doForward does not rely on `this`
+      doForward: (session: ProxySession, provider: Provider, baseUrl: string) => Promise<Response>;
+    };
+
+    await doForward(session, provider, provider.url);
+
+    expect(capturedUrl).toBe("https://openai.example.com/openai/v1/embeddings"); // provider base URL followed by the request path
+    expect(capturedUrl?.startsWith("https://mcp.example.com")).toBe(false); // and not the provider's mcpPassthroughUrl host
+  });
+});