Browse Source

Refactor: Optimize the request rate limiting for ModelRequestRateLimitCount.
Reason: The original steps 1 and 3 in the redisRateLimitHandler method were not atomic, leading to poor precision under high concurrent requests. For example, with a rate limit set to 60, sending 200 concurrent requests would result in none being blocked, whereas theoretically around 140 should be intercepted.
Solution: I chose not to merge steps 1 and 3 into a single Lua script because a single atomic operation involving read, write, and delete operations could suffer from performance issues under high concurrency. Instead, I implemented a token bucket algorithm to optimize this, reducing the atomic operation to just read and write steps while significantly decreasing the memory footprint.

霍雨佳 8 months ago
parent
commit
eb75ff232f
3 changed files with 160 additions and 14 deletions
  1. 94 0
      common/limiter/limiter.go
  2. 44 0
      common/limiter/lua/rate_limit.lua
  3. 22 14
      middleware/model-rate-limit.go

+ 94 - 0
common/limiter/limiter.go

@@ -0,0 +1,94 @@
+package limiter
+
+import (
+	"context"
+	_ "embed"
+	"fmt"
+	"github.com/go-redis/redis/v8"
+	"sync"
+)
+
+//go:embed lua/rate_limit.lua
+var rateLimitScript string
+
+type RedisLimiter struct {
+	client         *redis.Client
+	limitScriptSHA string
+}
+
+var (
+	instance *RedisLimiter
+	once     sync.Once
+)
+
+func New(ctx context.Context, r *redis.Client) *RedisLimiter {
+	once.Do(func() {
+		client := r
+		_, err := client.Ping(ctx).Result()
+		if err != nil {
+			panic(err) // 或者处理连接错误
+		}
+		// 预加载脚本
+		limitSHA, err := client.ScriptLoad(ctx, rateLimitScript).Result()
+		if err != nil {
+			fmt.Println(err)
+		}
+
+		instance = &RedisLimiter{
+			client:         client,
+			limitScriptSHA: limitSHA,
+		}
+	})
+
+	return instance
+}
+
+func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
+	// 默认配置
+	config := &Config{
+		Capacity:  10,
+		Rate:      1,
+		Requested: 1,
+	}
+
+	// 应用选项模式
+	for _, opt := range opts {
+		opt(config)
+	}
+
+	// 执行限流
+	result, err := rl.client.EvalSha(
+		ctx,
+		rl.limitScriptSHA,
+		[]string{key},
+		config.Requested,
+		config.Rate,
+		config.Capacity,
+	).Int()
+
+	if err != nil {
+		return false, fmt.Errorf("rate limit failed: %w", err)
+	}
+	return result == 1, nil
+}
+
+// Config 配置选项模式
+type Config struct {
+	Capacity  int64
+	Rate      int64
+	Requested int64
+}
+
+type Option func(*Config)
+
+func WithCapacity(c int64) Option {
+	return func(cfg *Config) { cfg.Capacity = c }
+}
+
+func WithRate(r int64) Option {
+	return func(cfg *Config) { cfg.Rate = r }
+}
+
+func WithRequested(n int64) Option {
+	return func(cfg *Config) { cfg.Requested = n }
+}

+ 44 - 0
common/limiter/lua/rate_limit.lua

@@ -0,0 +1,44 @@
+-- 令牌桶限流器
+-- KEYS[1]: 限流器唯一标识
+-- ARGV[1]: 请求令牌数 (通常为1)
+-- ARGV[2]: 令牌生成速率 (每秒)
+-- ARGV[3]: 桶容量
+
+local key = KEYS[1]
+local requested = tonumber(ARGV[1])
+local rate = tonumber(ARGV[2])
+local capacity = tonumber(ARGV[3])
+
+-- 获取当前时间(Redis服务器时间)
+local now = redis.call('TIME')
+local nowInSeconds = tonumber(now[1])
+
+-- 获取桶状态
+local bucket = redis.call('HMGET', key, 'tokens', 'last_time')
+local tokens = tonumber(bucket[1])
+local last_time = tonumber(bucket[2])
+
+-- 初始化桶(首次请求或过期)
+if not tokens or not last_time then
+    tokens = capacity
+    last_time = nowInSeconds
+else
+    -- 计算新增令牌
+    local elapsed = nowInSeconds - last_time
+    local add_tokens = elapsed * rate
+    tokens = math.min(capacity, tokens + add_tokens)
+    last_time = nowInSeconds
+end
+
+-- 判断是否允许请求
+local allowed = false
+if tokens >= requested then
+    tokens = tokens - requested
+    allowed = true
+end
+
+---- 更新桶状态并设置过期时间
+redis.call('HMSET', key, 'tokens', tokens, 'last_time', last_time)
+--redis.call('EXPIRE', key, math.ceil(capacity / rate) + 60) -- 适当延长过期时间
+
+return allowed and 1 or 0

+ 22 - 14
middleware/model-rate-limit.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"net/http"
 	"one-api/common"
+	"one-api/common/limiter"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -78,34 +79,41 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
 		ctx := context.Background()
 		rdb := common.RDB
 
-		// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
-		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
-		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		// 1. 检查成功请求数限制
+		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
+		allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
-			fmt.Println("检查请求数限制失败:", err.Error())
+			fmt.Println("检查成功请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
 			return
 		}
 		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+			return
 		}
+		//检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
+		totalKey := fmt.Sprintf("rateLimit:%s", userId)
+		//allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		// 初始化
+		tb := limiter.New(ctx, rdb)
+		allowed, err = tb.Allow(
+			ctx,
+			totalKey,
+			limiter.WithCapacity(int64(totalMaxCount)*duration),
+			limiter.WithRate(int64(totalMaxCount)),
+			limiter.WithRequested(duration),
+		)
 
-		// 2. 检查成功请求数限制
-		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
-		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
 		if err != nil {
-			fmt.Println("检查成功请求数限制失败:", err.Error())
+			fmt.Println("检查总请求数限制失败:", err.Error())
 			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
 			return
 		}
+
 		if !allowed {
-			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
-			return
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
 		}
 
-		// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
-		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
-
 		// 4. 处理请求
 		c.Next()