Browse Source

feat(proxy): introduce configurable Guard Pipeline system

- Add RequestType enum (CHAT, COUNT_TOKENS)
- Define GuardConfig, GuardPipeline, GuardStep interfaces and adapters
- Implement GuardPipelineBuilder to assemble dynamic guard chains
- Provide presets: CHAT_PIPELINE and COUNT_TOKENS_PIPELINE
- Wire proxy-handler to use pipeline and skip concurrency for count_tokens
- Add ProxySession.isCountTokensRequest() helper
ding113 4 months ago
parent
commit
072d3fc893

+ 11 - 56
src/app/v1/_lib/proxy-handler.ts

@@ -1,13 +1,7 @@
 import type { Context } from "hono";
 import { logger } from "@/lib/logger";
 import { ProxySession } from "./proxy/session";
-import { ProxyAuthenticator } from "./proxy/auth-guard";
-import { ProxyVersionGuard } from "./proxy/version-guard";
-import { ProxySessionGuard } from "./proxy/session-guard";
-import { ProxySensitiveWordGuard } from "./proxy/sensitive-word-guard";
-import { ProxyRateLimitGuard } from "./proxy/rate-limit-guard";
-import { ProxyProviderResolver } from "./proxy/provider-selector";
-import { ProxyMessageService } from "./proxy/message-service";
+import { GuardPipelineBuilder, RequestType } from "./proxy/guard-pipeline";
 import { ProxyForwarder } from "./proxy/forwarder";
 import { ProxyResponseHandler } from "./proxy/response-handler";
 import { ProxyErrorHandler } from "./proxy/error-handler";
@@ -18,55 +12,16 @@ export async function handleProxyRequest(c: Context): Promise<Response> {
   const session = await ProxySession.fromContext(c);
 
   try {
-    // 1. 认证检查
-    const unauthorized = await ProxyAuthenticator.ensure(session);
-    if (unauthorized) {
-      return unauthorized;
-    }
-
-    // 2. 版本检查(在认证后、Session 分配前)
-    const upgradeRequired = await ProxyVersionGuard.ensure(session);
-    if (upgradeRequired) {
-      return upgradeRequired;
-    }
-
-    // 3. 探测请求拦截:立即返回,不执行任何后续逻辑
-    if (session.isProbeRequest()) {
-      logger.debug("[ProxyHandler] Probe request detected, returning mock success", {
-        messagesCount: session.getMessagesLength(),
-      });
-      return new Response(JSON.stringify({ input_tokens: 0 }), {
-        status: 200,
-        headers: { "Content-Type": "application/json" },
-      });
-    }
-
-    // 4. Session 分配
-    await ProxySessionGuard.ensure(session);
-
-    // 5. 敏感词检查(在计费之前)
-    const blockedBySensitiveWord = await ProxySensitiveWordGuard.ensure(session);
-    if (blockedBySensitiveWord) {
-      return blockedBySensitiveWord;
-    }
-
-    // 6. 限流检查
-    const rateLimited = await ProxyRateLimitGuard.ensure(session);
-    if (rateLimited) {
-      return rateLimited;
-    }
-
-    // 7. 供应商选择
-    const providerUnavailable = await ProxyProviderResolver.ensure(session);
-    if (providerUnavailable) {
-      return providerUnavailable;
-    }
+    // Decide request type and build configured guard pipeline
+    const type = session.isCountTokensRequest() ? RequestType.COUNT_TOKENS : RequestType.CHAT;
+    const pipeline = GuardPipelineBuilder.fromRequestType(type);
 
-    // 8. 创建消息上下文(正常请求才写入数据库)
-    await ProxyMessageService.ensureContext(session);
+    // Run guard chain; may return early Response
+    const early = await pipeline.run(session);
+    if (early) return early;
 
-    // 9. 增加并发计数(在所有检查通过后,请求开始前)
-    if (session.sessionId) {
+    // 9. 增加并发计数(在所有检查通过后,请求开始前)- 跳过 count_tokens
+    if (session.sessionId && !session.isCountTokensRequest()) {
       await SessionTracker.incrementConcurrentCount(session.sessionId);
     }
 
@@ -90,8 +45,8 @@ export async function handleProxyRequest(c: Context): Promise<Response> {
     logger.error("Proxy handler error:", error);
     return await ProxyErrorHandler.handle(session, error);
   } finally {
-    // 11. 减少并发计数(确保无论成功失败都执行)
-    if (session.sessionId) {
+    // 11. 减少并发计数(确保无论成功失败都执行)- 跳过 count_tokens
+    if (session.sessionId && !session.isCountTokensRequest()) {
       await SessionTracker.decrementConcurrentCount(session.sessionId);
     }
   }

+ 139 - 0
src/app/v1/_lib/proxy/guard-pipeline.ts

@@ -0,0 +1,139 @@
+import type { ProxySession } from './session';
+import { ProxyAuthenticator } from './auth-guard';
+import { ProxyVersionGuard } from './version-guard';
+import { ProxySessionGuard } from './session-guard';
+import { ProxySensitiveWordGuard } from './sensitive-word-guard';
+import { ProxyRateLimitGuard } from './rate-limit-guard';
+import { ProxyProviderResolver } from './provider-selector';
+import { ProxyMessageService } from './message-service';
+
+// Request type classification for pipeline presets
+export enum RequestType {
+  CHAT = 'CHAT',
+  COUNT_TOKENS = 'COUNT_TOKENS',
+}
+
+// A single guard step that can mutate session or produce an early Response
+export interface GuardStep {
+  name: string;
+  execute(session: ProxySession): Promise<Response | null>;
+}
+
+// Pipeline configuration describes an ordered list of step keys
+export type GuardStepKey =
+  | 'auth'
+  | 'version'
+  | 'probe'
+  | 'session'
+  | 'sensitive'
+  | 'rateLimit'
+  | 'provider'
+  | 'messageContext';
+
+export interface GuardConfig {
+  steps: GuardStepKey[];
+}
+
+export interface GuardPipeline {
+  run(session: ProxySession): Promise<Response | null>;
+}
+
+// Concrete GuardStep implementations (adapters over existing guards)
+const Steps: Record<GuardStepKey, GuardStep> = {
+  auth: {
+    name: 'auth',
+    async execute(session) {
+      return ProxyAuthenticator.ensure(session);
+    },
+  },
+  version: {
+    name: 'version',
+    async execute(session) {
+      return ProxyVersionGuard.ensure(session);
+    },
+  },
+  probe: {
+    name: 'probe',
+    async execute(session) {
+      if (session.isProbeRequest()) {
+        return new Response(JSON.stringify({ input_tokens: 0 }), {
+          status: 200,
+          headers: { 'Content-Type': 'application/json' },
+        });
+      }
+      return null;
+    },
+  },
+  session: {
+    name: 'session',
+    async execute(session) {
+      await ProxySessionGuard.ensure(session);
+      return null;
+    },
+  },
+  sensitive: {
+    name: 'sensitive',
+    async execute(session) {
+      return ProxySensitiveWordGuard.ensure(session);
+    },
+  },
+  rateLimit: {
+    name: 'rateLimit',
+    async execute(session) {
+      return ProxyRateLimitGuard.ensure(session);
+    },
+  },
+  provider: {
+    name: 'provider',
+    async execute(session) {
+      return ProxyProviderResolver.ensure(session);
+    },
+  },
+  messageContext: {
+    name: 'messageContext',
+    async execute(session) {
+      await ProxyMessageService.ensureContext(session);
+      return null;
+    },
+  },
+};
+
+export class GuardPipelineBuilder {
+  // Assemble a pipeline from a configuration
+  static build(config: GuardConfig): GuardPipeline {
+    const steps: GuardStep[] = config.steps.map((k) => Steps[k]);
+
+    return {
+      async run(session: ProxySession): Promise<Response | null> {
+        for (const step of steps) {
+          const res = await step.execute(session);
+          if (res) return res; // early exit
+        }
+        return null;
+      },
+    };
+  }
+
+  // Convenience: build a pipeline from preset request type
+  static fromRequestType(type: RequestType): GuardPipeline {
+    switch (type) {
+      case RequestType.COUNT_TOKENS:
+        return GuardPipelineBuilder.build(COUNT_TOKENS_PIPELINE);
+      case RequestType.CHAT:
+      default:
+        return GuardPipelineBuilder.build(CHAT_PIPELINE);
+    }
+  }
+}
+
+// Preset configurations
+export const CHAT_PIPELINE: GuardConfig = {
+  // Full guard chain for normal chat requests
+  steps: ['auth', 'version', 'probe', 'session', 'sensitive', 'rateLimit', 'provider', 'messageContext'],
+};
+
+export const COUNT_TOKENS_PIPELINE: GuardConfig = {
+  // Minimal chain for count_tokens: no session, no sensitive, no rate limit, no message logging
+  steps: ['auth', 'version', 'probe', 'provider'],
+};
+

+ 9 - 0
src/app/v1/_lib/proxy/session.ts

@@ -308,6 +308,15 @@ export class ProxySession {
     }
   }
 
+  /**
+   * 是否为 count_tokens 请求端点
+   * - 依据 URL pathname 判断:/v1/messages/count_tokens
+   */
+  isCountTokensRequest(): boolean {
+    const endpoint = this.getEndpoint();
+    return endpoint === "/v1/messages/count_tokens";
+  }
+
   /**
    * 设置原始模型(在重定向前调用)
    * 只能设置一次,避免多次重定向覆盖