Jelajahi Sumber

fix: 修复供应商并发限制被超的竞态条件 Bug

## 问题根因

1. **Redis 数据结构冲突**:provider:*:active_sessions 仍为旧 Set 格式,导致计数失效返回 0
2. **竞态条件**:并发检查与 Session 追踪之间无原子性保障,多个请求同时通过检查
3. **异步追踪**:updateProvider 使用 void 异步执行,扩大竞态窗口

## 修复内容

### 1. 清理旧数据格式 (session-tracker.ts)
- 增强 SessionTracker.initialize() 自动清理旧 Set 数据
- 移除所有 countFromSet 兼容逻辑
- 计数方法遇到非 ZSET 类型直接删除并返回 0

### 2. 实现 Lua 原子脚本 (lua-scripts.ts)
- 新增 CHECK_AND_TRACK_SESSION Lua 脚本
- 原子性执行:清理过期 Session → 检查限制 → 追踪 Session
- 消除检查与追踪之间的竞态窗口

### 3. 重构 RateLimitService (rate-limit/service.ts)
- 新增 checkAndTrackProviderSession 原子性检查方法
- 使用 Lua 脚本保证并发安全
- 返回 {allowed, count, tracked} 三元组

### 4. 重构 ProxyProviderResolver (provider-selector.ts)
- 在选择供应商后立即执行原子性检查并追踪
- 移除 filterByLimits 中的并发检查(避免竞态)
- 移除异步 updateProvider 调用
- 检查失败直接返回 503 错误

## 修复效果

- 并发限制严格生效(limit=1 时最多 1 个活跃 Session)
- 消除竞态条件,保证原子性
- 代码更简洁(消除特殊情况)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
ding113 4 bulan lalu
induk
melakukan
ac5106fc9d

+ 40 - 30
src/app/v1/_lib/proxy/provider-selector.ts

@@ -2,7 +2,6 @@ import type { Provider } from "@/types/provider";
 import { findProviderList, findProviderById } from "@/repository/provider";
 import { RateLimitService } from "@/lib/rate-limit";
 import { SessionManager } from "@/lib/session-manager";
-import { SessionTracker } from "@/lib/session-tracker";
 import { isCircuitOpen, getCircuitState } from "@/lib/circuit-breaker";
 import { ProxyLogger } from "./logger";
 import { ProxyResponses } from "./responses";
@@ -25,8 +24,33 @@ export class ProxyProviderResolver {
       session.setProvider(await ProxyProviderResolver.pickRandomProvider(session, [], targetProviderType));
     }
 
-    // 关键修复:选定供应商后立即记录到决策链并绑定到 session
-    if (session.provider) {
+    // 选定供应商后,进行原子性并发检查并追踪
+    if (session.provider && session.sessionId) {
+      const limit = session.provider.limitConcurrentSessions || 0;
+
+      // 使用原子性检查并追踪(解决竞态条件)
+      const checkResult = await RateLimitService.checkAndTrackProviderSession(
+        session.provider.id,
+        session.sessionId,
+        limit
+      );
+
+      if (!checkResult.allowed) {
+        // 并发限制失败
+        console.warn(
+          `[ProviderSelector] Provider ${session.provider.name} concurrent session limit exceeded (${checkResult.count}/${limit})`
+        );
+
+        // 记录失败
+        await ProxyLogger.logFailure(session, new Error(checkResult.reason || 'Session limit exceeded'));
+        return ProxyResponses.buildError(503, checkResult.reason || '供应商并发限制已达到');
+      }
+
+      console.debug(
+        `[ProviderSelector] ✅ Session tracked atomically: ${session.sessionId} → ${session.provider.name} (count=${checkResult.count})`
+      );
+
+      // 记录到决策链
       session.addProviderToChain(session.provider, {
         reason: 'initial_selection',
         selectionMethod,
@@ -34,22 +58,15 @@ export class ProxyProviderResolver {
       });
 
       // 绑定 session 到 provider(异步,不阻塞)
-      if (session.sessionId) {
-        void SessionManager.bindSessionToProvider(session.sessionId, session.provider.id);
-
-        // 更新 session tracker 的 provider 信息(同时刷新时间戳)
-        void SessionTracker.updateProvider(session.sessionId, session.provider.id).catch((error) => {
-          console.error('[ProviderSelector] Failed to update session tracker provider:', error);
-        });
-
-        // 更新 session 详细信息中的 provider 信息
-        void SessionManager.updateSessionProvider(session.sessionId, {
-          providerId: session.provider.id,
-          providerName: session.provider.name,
-        }).catch((error) => {
-          console.error('[ProviderSelector] Failed to update session provider info:', error);
-        });
-      }
+      void SessionManager.bindSessionToProvider(session.sessionId, session.provider.id);
+
+      // 更新 session 详细信息中的 provider 信息
+      void SessionManager.updateSessionProvider(session.sessionId, {
+        providerId: session.provider.id,
+        providerName: session.provider.name,
+      }).catch((error) => {
+        console.error('[ProviderSelector] Failed to update session provider info:', error);
+      });
 
       return null;
     }
@@ -200,6 +217,9 @@ export class ProxyProviderResolver {
 
   /**
    * 过滤超限供应商
+   *
+   * 注意:并发 Session 限制检查已移至原子性检查(ensure 方法中),
+   * 此处仅检查金额限制和熔断器状态
    */
   private static async filterByLimits(providers: Provider[]): Promise<Provider[]> {
     const results = await Promise.all(
@@ -222,17 +242,7 @@ export class ProxyProviderResolver {
           return null;
         }
 
-        // 2. 检查并发 Session 限制
-        const sessionCheck = await RateLimitService.checkSessionLimit(
-          p.id,
-          'provider',
-          p.limitConcurrentSessions || 0
-        );
-
-        if (!sessionCheck.allowed) {
-          console.debug(`[ProviderSelector] Provider ${p.id} session limit exceeded`);
-          return null;
-        }
+        // 并发 Session 限制已移至原子性检查(avoid race condition)
 
         return p;
       })

+ 66 - 2
src/lib/rate-limit/service.ts

@@ -1,5 +1,6 @@
 import { getRedisClient } from '@/lib/redis';
 import { SessionTracker } from '@/lib/session-tracker';
+import { CHECK_AND_TRACK_SESSION } from '@/lib/redis/lua-scripts';
 
 interface CostLimit {
   amount: number | null;
@@ -73,7 +74,10 @@ export class RateLimitService {
   }
 
   /**
-   * 检查并发 Session 限制
+   * 检查并发 Session 限制(仅检查,不追踪)
+   *
+   * 注意:此方法仅用于非供应商级别的限流检查(如 key 级)
+   * 供应商级别请使用 checkAndTrackProviderSession 保证原子性
    */
   static async checkSessionLimit(
     id: number,
@@ -85,7 +89,7 @@ export class RateLimitService {
     }
 
     try {
-      // 使用 SessionTracker 的统一计数逻辑(自动兼容 ZSET/Set)
+      // 使用 SessionTracker 的统一计数逻辑
       const count = type === 'key'
         ? await SessionTracker.getKeySessionCount(id)
         : await SessionTracker.getProviderSessionCount(id);
@@ -104,6 +108,66 @@ export class RateLimitService {
     }
   }
 
+  /**
+   * 原子性检查并追踪供应商 Session(解决竞态条件)
+   *
+   * 使用 Lua 脚本保证"检查 + 追踪"的原子性,防止并发请求同时通过限制检查
+   *
+   * @param providerId - Provider ID
+   * @param sessionId - Session ID
+   * @param limit - 并发限制
+   * @returns { allowed, count, tracked } - 是否允许、当前并发数、是否已追踪
+   */
+  static async checkAndTrackProviderSession(
+    providerId: number,
+    sessionId: string,
+    limit: number
+  ): Promise<{ allowed: boolean; count: number; tracked: boolean; reason?: string }> {
+    if (limit <= 0) {
+      return { allowed: true, count: 0, tracked: false };
+    }
+
+    if (!this.redis || this.redis.status !== 'ready') {
+      console.warn('[RateLimit] Redis not ready, Fail Open');
+      return { allowed: true, count: 0, tracked: false };
+    }
+
+    try {
+      const key = `provider:${providerId}:active_sessions`;
+      const now = Date.now();
+
+      // 执行 Lua 脚本:原子性检查 + 追踪
+      const result = await this.redis.eval(
+        CHECK_AND_TRACK_SESSION,
+        1,  // KEYS count
+        key,  // KEYS[1]
+        sessionId,  // ARGV[1]
+        limit.toString(),  // ARGV[2]
+        now.toString()  // ARGV[3]
+      ) as [number, number];
+
+      const [allowed, count] = result;
+
+      if (allowed === 0) {
+        return {
+          allowed: false,
+          count,
+          tracked: false,
+          reason: `供应商并发 Session 上限已达到(${count}/${limit})`,
+        };
+      }
+
+      return {
+        allowed: true,
+        count,
+        tracked: true,  // Lua 脚本中已追踪
+      };
+    } catch (error) {
+      console.error('[RateLimit] Atomic check-and-track failed:', error);
+      return { allowed: true, count: 0, tracked: false }; // Fail Open
+    }
+  }
+
   /**
    * 累加消费(请求结束后调用)
    */

+ 91 - 0
src/lib/redis/lua-scripts.ts

@@ -0,0 +1,91 @@
+/**
+ * Redis Lua 脚本集合
+ *
+ * 用于保证 Redis 操作的原子性
+ */
+
+/**
+ * 原子性检查并发限制 + 追踪 Session
+ *
+ * 功能:
+ * 1. 清理过期 session(5 分钟前)
+ * 2. 检查当前并发数是否超限
+ * 3. 如果未超限,追踪新 session(原子操作)
+ *
+ * KEYS[1]: provider:${providerId}:active_sessions
+ * ARGV[1]: sessionId
+ * ARGV[2]: limit(并发限制)
+ * ARGV[3]: now(当前时间戳,毫秒)
+ *
+ * 返回值:
+ * - {1, newCount} - 允许(追踪成功),返回新的并发数
+ * - {0, currentCount} - 拒绝(超限),返回当前并发数
+ */
+export const CHECK_AND_TRACK_SESSION = `
+local provider_key = KEYS[1]
+local session_id = ARGV[1]
+local limit = tonumber(ARGV[2])
+local now = tonumber(ARGV[3])
+local ttl = 300000  -- 5 分钟(毫秒)
+
+-- 1. 清理过期 session(5 分钟前)
+local five_minutes_ago = now - ttl
+redis.call('ZREMRANGEBYSCORE', provider_key, '-inf', five_minutes_ago)
+
+-- 2. 获取当前并发数
+local current_count = redis.call('ZCARD', provider_key)
+
+-- 3. 检查限制
+if limit > 0 and current_count >= limit then
+  return {0, current_count}  -- {allowed=false, current_count}
+end
+
+-- 4. 追踪 session(原子操作)
+redis.call('ZADD', provider_key, now, session_id)
+redis.call('EXPIRE', provider_key, 3600)  -- 1 小时兜底 TTL
+
+-- 5. 返回成功(新的并发数 = current_count + 1)
+return {1, current_count + 1}  -- {allowed=true, new_count}
+`;
+
+/**
+ * 批量检查多个供应商的并发限制
+ *
+ * KEYS: provider:${providerId}:active_sessions (多个)
+ * ARGV[1]: sessionId
+ * ARGV[2...]: limits(每个供应商的并发限制)
+ * ARGV[N]: now(当前时间戳,毫秒)
+ *
+ * 返回值:数组,每个元素对应一个供应商
+ * - {1, count} - 允许
+ * - {0, count} - 拒绝(超限)
+ */
+export const BATCH_CHECK_SESSION_LIMITS = `
+local session_id = ARGV[1]
+local now = tonumber(ARGV[#ARGV])
+local ttl = 300000  -- 5 分钟(毫秒)
+local five_minutes_ago = now - ttl
+
+local results = {}
+
+-- 遍历所有供应商 key
+for i = 1, #KEYS do
+  local provider_key = KEYS[i]
+  local limit = tonumber(ARGV[i + 1])  -- ARGV[2]...ARGV[N-1]
+
+  -- 清理过期 session
+  redis.call('ZREMRANGEBYSCORE', provider_key, '-inf', five_minutes_ago)
+
+  -- 获取当前并发数
+  local current_count = redis.call('ZCARD', provider_key)
+
+  -- 检查限制
+  if limit > 0 and current_count >= limit then
+    table.insert(results, {0, current_count})  -- 拒绝
+  else
+    table.insert(results, {1, current_count})  -- 允许
+  end
+end
+
+return results
+`;

+ 36 - 87
src/lib/session-tracker.ts

@@ -20,8 +20,8 @@ export class SessionTracker {
   /**
    * 初始化 SessionTracker,自动清理旧格式数据
    *
-   * 应在应用启动时调用一次,清理重构前的 Set 类型数据,
-   * 确保所有集合都是 ZSET 格式
+   * 应在应用启动时调用一次,清理 global:active_sessions 的旧 Set 数据。
+   * 其他 key(provider:*、key:*)在运行时自动清理
    */
   static async initialize(): Promise<void> {
     const redis = getRedisClient();
@@ -37,15 +37,12 @@ export class SessionTracker {
       if (exists === 1) {
         const type = await redis.type(key);
 
-        if (type === 'set') {
-          console.warn(`[SessionTracker] Found legacy Set: ${key}, migrating to ZSET...`);
+        if (type !== 'zset') {
+          console.warn(`[SessionTracker] Found legacy format: ${key} (type=${type}), deleting...`);
           await redis.del(key);
-          console.info(`[SessionTracker] ✅ Successfully migrated ${key} to ZSET format`);
-        } else if (type === 'zset') {
-          console.debug(`[SessionTracker] ${key} is already ZSET format, no migration needed`);
+          console.info(`[SessionTracker] ✅ Deleted legacy ${key}`);
         } else {
-          console.warn(`[SessionTracker] Unexpected type for ${key}: ${type}, deleting...`);
-          await redis.del(key);
+          console.debug(`[SessionTracker] ${key} is already ZSET format`);
         }
       } else {
         console.debug(`[SessionTracker] ${key} does not exist, will be created on first use`);
@@ -198,8 +195,6 @@ export class SessionTracker {
   /**
    * 获取全局活跃 session 计数
    *
-   * 自动兼容新旧格式(ZSET/Set)
-   *
    * @returns 活跃 session 数量
    */
   static async getGlobalSessionCount(): Promise<number> {
@@ -213,14 +208,13 @@ export class SessionTracker {
       if (exists === 1) {
         const type = await redis.type(key);
 
-        if (type === 'zset') {
-          // 新格式:从 ZSET 读取
-          return await this.countFromZSet(key);
-        } else {
-          // 旧格式:从 Set 读取(兼容模式)
-          console.debug('[SessionTracker] Using legacy Set format (will expire in 5 min)');
-          return await this.countFromSet(key);
+        if (type !== 'zset') {
+          console.warn(`[SessionTracker] ${key} is not ZSET (type=${type}), deleting...`);
+          await redis.del(key);
+          return 0;
         }
+
+        return await this.countFromZSet(key);
       }
 
       return 0;
@@ -247,12 +241,13 @@ export class SessionTracker {
       if (exists === 1) {
         const type = await redis.type(key);
 
-        if (type === 'zset') {
-          return await this.countFromZSet(key);
-        } else {
-          console.debug(`[SessionTracker] Key ${keyId}: Using legacy Set format`);
-          return await this.countFromSet(key);
+        if (type !== 'zset') {
+          console.warn(`[SessionTracker] ${key} is not ZSET (type=${type}), deleting...`);
+          await redis.del(key);
+          return 0;
         }
+
+        return await this.countFromZSet(key);
       }
 
       return 0;
@@ -279,12 +274,13 @@ export class SessionTracker {
       if (exists === 1) {
         const type = await redis.type(key);
 
-        if (type === 'zset') {
-          return await this.countFromZSet(key);
-        } else {
-          console.debug(`[SessionTracker] Provider ${providerId}: Using legacy Set format`);
-          return await this.countFromSet(key);
+        if (type !== 'zset') {
+          console.warn(`[SessionTracker] ${key} is not ZSET (type=${type}), deleting...`);
+          await redis.del(key);
+          return 0;
         }
+
+        return await this.countFromZSet(key);
       }
 
       return 0;
@@ -310,20 +306,20 @@ export class SessionTracker {
       if (exists === 1) {
         const type = await redis.type(key);
 
-        if (type === 'zset') {
-          // 新格式:从 ZSET 读取
-          const now = Date.now();
-          const fiveMinutesAgo = now - this.SESSION_TTL;
+        if (type !== 'zset') {
+          console.warn(`[SessionTracker] ${key} is not ZSET (type=${type}), deleting...`);
+          await redis.del(key);
+          return [];
+        }
+
+        const now = Date.now();
+        const fiveMinutesAgo = now - this.SESSION_TTL;
 
-          // 清理过期 session
-          await redis.zremrangebyscore(key, '-inf', fiveMinutesAgo);
+        // 清理过期 session
+        await redis.zremrangebyscore(key, '-inf', fiveMinutesAgo);
 
-          // 获取剩余的 session ID
-          return await redis.zrange(key, 0, -1);
-        } else {
-          // 旧格式:从 Set 读取
-          return await redis.smembers(key);
-        }
+        // 获取剩余的 session ID
+        return await redis.zrange(key, 0, -1);
       }
 
       return [];
@@ -386,51 +382,4 @@ export class SessionTracker {
     }
   }
 
-  /**
-   * 从 Set 计数(旧格式 - 兼容模式)
-   *
-   * 实现步骤:
-   * 1. SMEMBERS 获取所有 session ID
-   * 2. 批量 EXISTS 验证 session:${sessionId}:info 是否存在
-   * 3. 统计真实存在的 session
-   *
-   * 注意:这是兼容旧数据的方法,5 分钟后旧数据自动过期,将全部切换到 ZSET
-   *
-   * @param key - Redis key
-   * @returns 有效 session 数量
-   */
-  private static async countFromSet(key: string): Promise<number> {
-    const redis = getRedisClient();
-    if (!redis || redis.status !== 'ready') return 0;
-
-    try {
-      // 1. 获取所有 session ID
-      const sessionIds = await redis.smembers(key);
-      if (sessionIds.length === 0) return 0;
-
-      // 2. 批量验证 info 是否存在
-      const pipeline = redis.pipeline();
-      for (const sessionId of sessionIds) {
-        pipeline.exists(`session:${sessionId}:info`);
-      }
-      const results = await pipeline.exec();
-      if (!results) return 0;
-
-      // 3. 统计有效 session
-      let count = 0;
-      for (const result of results) {
-        if (result && result[0] === null && result[1] === 1) {
-          count++;
-        }
-      }
-
-      console.debug(
-        `[SessionTracker] Set ${key} (legacy): ${count} valid sessions (from ${sessionIds.length} total)`
-      );
-      return count;
-    } catch (error) {
-      console.error('[SessionTracker] Failed to count from Set:', error);
-      return 0;
-    }
-  }
 }