Browse Source

fix: remove claude-* model prefix routing restriction (#832)

Simplify providerSupportsModel() to treat all provider types uniformly.
Previously, claude-* models were hardcoded to only route to claude/claude-auth
providers, even when other providers explicitly declared them in allowedModels.
Now the logic is: explicit allowedModels/modelRedirects match -> accept;
empty allowedModels -> wildcard; otherwise reject. Format compatibility
remains enforced by checkFormatProviderTypeCompatibility independently.
ding113 1 month ago
parent
commit
a82e994376

+ 14 - 52
src/app/v1/_lib/proxy/provider-selector.ts

@@ -98,71 +98,33 @@ function checkProviderGroupMatch(providerGroupTag: string | null, userGroups: st
 /**
  * 检查供应商是否支持指定模型(用于调度器匹配)
  *
- * 核心逻辑:
- * 1. Claude 模型请求 (claude-*):
- *    - Anthropic 提供商:根据 allowedModels 白名单判断
- *    - 非 Anthropic 提供商:不支持 claude-* 模型调度
+ * 核心逻辑(统一所有供应商类型)
+ * 1. 显式声明优先:allowedModels 包含或 modelRedirects 包含 -> 支持
+ * 2. 未设置 allowedModels(null 或空数组):接受任意模型(格式兼容性由 checkFormatProviderTypeCompatibility 保证)
+ * 3. 设置了 allowedModels 但不包含该模型 -> 不支持
  *
- * 2. 非 Claude 模型请求 (gpt-*, gemini-*, 或其他任意模型):
- *    - Anthropic 提供商:不支持(仅支持 Claude 模型)
- *    - 非 Anthropic 提供商(codex, gemini-cli, openai-compatible):
- *      a. 如果未设置 allowedModels(null 或空数组):接受任意模型
- *      b. 如果设置了 allowedModels:检查模型是否在声明列表中,或有模型重定向配置
- *      注意:allowedModels 是声明性列表(用户可填写任意字符串),用于调度器匹配,不是真实模型校验
+ * 注意:allowedModels 是声明性列表(用户可填写任意字符串),用于调度器匹配,不是真实模型校验。
+ * 格式兼容性(如 claude 格式请求只路由到 claude 类型供应商)由 checkFormatProviderTypeCompatibility 独立保证。
  *
  * @param provider - 供应商信息
  * @param requestedModel - 用户请求的模型名称
  * @returns 是否支持该模型(用于调度器筛选)
  */
 function providerSupportsModel(provider: Provider, requestedModel: string): boolean {
-  const isClaudeModel = requestedModel.startsWith("claude-");
-  const isClaudeProvider =
-    provider.providerType === "claude" || provider.providerType === "claude-auth";
-
-  // Case 1: Claude 模型请求
-  if (isClaudeModel) {
-    // 1a. Anthropic 提供商
-    if (isClaudeProvider) {
-      // 未设置 allowedModels 或为空数组:允许所有 claude 模型
-      if (!provider.allowedModels || provider.allowedModels.length === 0) {
-        return true;
-      }
-      // 检查白名单
-      return provider.allowedModels.includes(requestedModel);
-    }
-
-    // 1b. 非 Anthropic 提供商不支持 Claude 模型调度
-    return false;
-  }
-
-  // Case 2: 非 Claude 模型请求(gpt-*, gemini-*, 或其他任意模型)
-  // 2a. 优先检查显式声明(支持跨类型代理)
-  // 原因:允许 Claude 类型供应商通过 allowedModels/modelRedirects 声明支持非 Claude 模型
-  // 场景:Claude 供应商配置模型重定向,将 gemini-* 请求转发到真实的 Gemini 上游
-  const explicitlyDeclared = !!(
-    provider.allowedModels?.includes(requestedModel) || provider.modelRedirects?.[requestedModel]
-  );
-
-  if (explicitlyDeclared) {
-    return true; // 显式声明优先级最高,允许跨类型代理
-  }
-
-  // 2b. Anthropic 提供商不支持非声明的非 Claude 模型
-  // 保护机制:防止将非 Claude 模型误路由到 Anthropic API
-  if (isClaudeProvider) {
-    return false;
+  // 1. 显式声明优先(allowedModels 或 modelRedirects)
+  if (
+    provider.allowedModels?.includes(requestedModel) ||
+    provider.modelRedirects?.[requestedModel]
+  ) {
+    return true;
   }
 
-  // 2c. 非 Anthropic 提供商(codex, gemini, gemini-cli, openai-compatible)
-  // allowedModels 是声明列表,用于调度器匹配提供商
-  // 用户可以手动填写任意模型名称(不限于真实模型),用于声明该提供商"支持"哪些模型
-
-  // 未设置 allowedModels 或为空数组:接受任意模型(由上游提供商判断)
+  // 2. 未设置 allowedModels(null 或空数组):接受任意模型
   if (!provider.allowedModels || provider.allowedModels.length === 0) {
     return true;
   }
 
-  // 不在声明列表中且无重定向配置(前面已检查过 explicitlyDeclared)
+  // 3. 设置了 allowedModels 但不包含该模型,且无 modelRedirects
   return false;
 }
 

+ 265 - 0
tests/unit/proxy/provider-selector-cross-type-model.test.ts

@@ -0,0 +1,265 @@
+import { beforeEach, describe, expect, test, vi } from "vitest";
+import type { Provider } from "@/types/provider";
+
+const circuitBreakerMocks = vi.hoisted(() => ({
+  isCircuitOpen: vi.fn(async () => false),
+  getCircuitState: vi.fn(() => "closed"),
+}));
+
+vi.mock("@/lib/circuit-breaker", () => circuitBreakerMocks);
+
+const vendorTypeCircuitMocks = vi.hoisted(() => ({
+  isVendorTypeCircuitOpen: vi.fn(async () => false),
+}));
+
+vi.mock("@/lib/vendor-type-circuit-breaker", () => vendorTypeCircuitMocks);
+
+const sessionManagerMocks = vi.hoisted(() => ({
+  SessionManager: {
+    getSessionProvider: vi.fn(async () => null as number | null),
+    clearSessionProvider: vi.fn(async () => undefined),
+  },
+}));
+
+vi.mock("@/lib/session-manager", () => sessionManagerMocks);
+
+const providerRepositoryMocks = vi.hoisted(() => ({
+  findProviderById: vi.fn(async () => null as Provider | null),
+  findAllProviders: vi.fn(async () => [] as Provider[]),
+}));
+
+vi.mock("@/repository/provider", () => providerRepositoryMocks);
+
+const rateLimitMocks = vi.hoisted(() => ({
+  RateLimitService: {
+    checkCostLimitsWithLease: vi.fn(async () => ({ allowed: true })),
+    checkTotalCostLimit: vi.fn(async () => ({ allowed: true, current: 0 })),
+  },
+}));
+
+vi.mock("@/lib/rate-limit", () => rateLimitMocks);
+
+beforeEach(() => {
+  vi.resetAllMocks();
+});
+
+function createProvider(overrides: Partial<Provider> = {}): Provider {
+  return {
+    id: 1,
+    name: "test-provider",
+    isEnabled: true,
+    providerType: "openai-compatible",
+    groupTag: null,
+    weight: 1,
+    priority: 0,
+    costMultiplier: 1,
+    allowedModels: null,
+    providerVendorId: null,
+    limit5hUsd: null,
+    limitDailyUsd: null,
+    dailyResetMode: "fixed",
+    dailyResetTime: "00:00",
+    limitWeeklyUsd: null,
+    limitMonthlyUsd: null,
+    limitTotalUsd: null,
+    totalCostResetAt: null,
+    limitConcurrentSessions: 0,
+    ...overrides,
+  } as unknown as Provider;
+}
+
+describe("providerSupportsModel - cross-type model routing (#832)", () => {
+  test("openai-compatible provider with claude model in allowedModels should match", async () => {
+    const { ProxyProviderResolver } = await import("@/app/v1/_lib/proxy/provider-selector");
+
+    const provider = createProvider({
+      id: 10,
+      providerType: "openai-compatible",
+      allowedModels: ["claude-opus-4-6"],
+    });
+
+    sessionManagerMocks.SessionManager.getSessionProvider.mockResolvedValueOnce(10);
+    providerRepositoryMocks.findProviderById.mockResolvedValueOnce(provider);
+    rateLimitMocks.RateLimitService.checkCostLimitsWithLease.mockResolvedValueOnce({
+      allowed: true,
+    });
+    rateLimitMocks.RateLimitService.checkTotalCostLimit.mockResolvedValueOnce({
+      allowed: true,
+      current: 0,
+    });
+
+    const session = {
+      sessionId: "cross-type-1",
+      shouldReuseProvider: () => true,
+      getOriginalModel: () => "claude-opus-4-6",
+      authState: null,
+      getCurrentModel: () => null,
+    } as any;
+
+    const result = await (ProxyProviderResolver as any).findReusable(session);
+
+    expect(result).not.toBeNull();
+    expect(result?.id).toBe(10);
+  });
+
+  test("openai-compatible provider with empty allowedModels should match any model (wildcard)", async () => {
+    const { ProxyProviderResolver } = await import("@/app/v1/_lib/proxy/provider-selector");
+
+    const provider = createProvider({
+      id: 11,
+      providerType: "openai-compatible",
+      allowedModels: null,
+    });
+
+    sessionManagerMocks.SessionManager.getSessionProvider.mockResolvedValueOnce(11);
+    providerRepositoryMocks.findProviderById.mockResolvedValueOnce(provider);
+    rateLimitMocks.RateLimitService.checkCostLimitsWithLease.mockResolvedValueOnce({
+      allowed: true,
+    });
+    rateLimitMocks.RateLimitService.checkTotalCostLimit.mockResolvedValueOnce({
+      allowed: true,
+      current: 0,
+    });
+
+    const session = {
+      sessionId: "cross-type-2",
+      shouldReuseProvider: () => true,
+      getOriginalModel: () => "claude-sonnet-4-5-20250929",
+      authState: null,
+      getCurrentModel: () => null,
+    } as any;
+
+    const result = await (ProxyProviderResolver as any).findReusable(session);
+
+    expect(result).not.toBeNull();
+    expect(result?.id).toBe(11);
+  });
+
+  test("openai-compatible provider with allowedModels NOT containing the model should not match", async () => {
+    const { ProxyProviderResolver } = await import("@/app/v1/_lib/proxy/provider-selector");
+
+    const provider = createProvider({
+      id: 12,
+      providerType: "openai-compatible",
+      allowedModels: ["gpt-4o", "gpt-4o-mini"],
+    });
+
+    sessionManagerMocks.SessionManager.getSessionProvider.mockResolvedValueOnce(12);
+    providerRepositoryMocks.findProviderById.mockResolvedValueOnce(provider);
+
+    const session = {
+      sessionId: "cross-type-3",
+      shouldReuseProvider: () => true,
+      getOriginalModel: () => "claude-opus-4-6",
+      authState: null,
+      getCurrentModel: () => null,
+    } as any;
+
+    const result = await (ProxyProviderResolver as any).findReusable(session);
+
+    expect(result).toBeNull();
+    expect(sessionManagerMocks.SessionManager.clearSessionProvider).toHaveBeenCalledWith(
+      "cross-type-3"
+    );
+  });
+
+  test("claude provider with empty allowedModels should match any model (wildcard)", async () => {
+    const { ProxyProviderResolver } = await import("@/app/v1/_lib/proxy/provider-selector");
+
+    const provider = createProvider({
+      id: 13,
+      providerType: "claude",
+      allowedModels: null,
+    });
+
+    sessionManagerMocks.SessionManager.getSessionProvider.mockResolvedValueOnce(13);
+    providerRepositoryMocks.findProviderById.mockResolvedValueOnce(provider);
+    rateLimitMocks.RateLimitService.checkCostLimitsWithLease.mockResolvedValueOnce({
+      allowed: true,
+    });
+    rateLimitMocks.RateLimitService.checkTotalCostLimit.mockResolvedValueOnce({
+      allowed: true,
+      current: 0,
+    });
+
+    const session = {
+      sessionId: "cross-type-4",
+      shouldReuseProvider: () => true,
+      getOriginalModel: () => "gpt-4o",
+      authState: null,
+      getCurrentModel: () => null,
+    } as any;
+
+    const result = await (ProxyProviderResolver as any).findReusable(session);
+
+    expect(result).not.toBeNull();
+    expect(result?.id).toBe(13);
+  });
+
+  test("claude provider with non-claude model in allowedModels should match (explicit declaration)", async () => {
+    const { ProxyProviderResolver } = await import("@/app/v1/_lib/proxy/provider-selector");
+
+    const provider = createProvider({
+      id: 14,
+      providerType: "claude",
+      allowedModels: ["gemini-2.5-pro"],
+    });
+
+    sessionManagerMocks.SessionManager.getSessionProvider.mockResolvedValueOnce(14);
+    providerRepositoryMocks.findProviderById.mockResolvedValueOnce(provider);
+    rateLimitMocks.RateLimitService.checkCostLimitsWithLease.mockResolvedValueOnce({
+      allowed: true,
+    });
+    rateLimitMocks.RateLimitService.checkTotalCostLimit.mockResolvedValueOnce({
+      allowed: true,
+      current: 0,
+    });
+
+    const session = {
+      sessionId: "cross-type-5",
+      shouldReuseProvider: () => true,
+      getOriginalModel: () => "gemini-2.5-pro",
+      authState: null,
+      getCurrentModel: () => null,
+    } as any;
+
+    const result = await (ProxyProviderResolver as any).findReusable(session);
+
+    expect(result).not.toBeNull();
+    expect(result?.id).toBe(14);
+  });
+
+  test("any provider with modelRedirects containing the model should match", async () => {
+    const { ProxyProviderResolver } = await import("@/app/v1/_lib/proxy/provider-selector");
+
+    const provider = createProvider({
+      id: 15,
+      providerType: "openai-compatible",
+      allowedModels: ["gpt-4o"],
+      modelRedirects: { "claude-opus-4-6": "custom-opus" },
+    });
+
+    sessionManagerMocks.SessionManager.getSessionProvider.mockResolvedValueOnce(15);
+    providerRepositoryMocks.findProviderById.mockResolvedValueOnce(provider);
+    rateLimitMocks.RateLimitService.checkCostLimitsWithLease.mockResolvedValueOnce({
+      allowed: true,
+    });
+    rateLimitMocks.RateLimitService.checkTotalCostLimit.mockResolvedValueOnce({
+      allowed: true,
+      current: 0,
+    });
+
+    const session = {
+      sessionId: "cross-type-6",
+      shouldReuseProvider: () => true,
+      getOriginalModel: () => "claude-opus-4-6",
+      authState: null,
+      getCurrentModel: () => null,
+    } as any;
+
+    const result = await (ProxyProviderResolver as any).findReusable(session);
+
+    expect(result).not.toBeNull();
+    expect(result?.id).toBe(15);
+  });
+});