|
|
@@ -5,6 +5,7 @@ import (
|
|
|
"fmt"
|
|
|
"net/http"
|
|
|
"one-api/common"
|
|
|
+ "one-api/common/limiter"
|
|
|
"one-api/setting"
|
|
|
"strconv"
|
|
|
"time"
|
|
|
@@ -78,34 +79,41 @@ func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) g
|
|
|
ctx := context.Background()
|
|
|
rdb := common.RDB
|
|
|
|
|
|
- // 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
|
|
|
- totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
|
|
|
- allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
|
|
|
+ // 1. 检查成功请求数限制
|
|
|
+ successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
|
|
+ allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
|
|
if err != nil {
|
|
|
- fmt.Println("检查总请求数限制失败:", err.Error())
|
|
|
+ fmt.Println("检查成功请求数限制失败:", err.Error())
|
|
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
|
|
return
|
|
|
}
|
|
|
if !allowed {
|
|
|
- abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
|
|
+ abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
|
|
|
+ return
|
|
|
}
|
|
|
+ //检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
|
|
|
+ totalKey := fmt.Sprintf("rateLimit:%s", userId)
|
|
|
+ //allowed, err = checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
|
|
|
+ // 初始化
|
|
|
+ tb := limiter.New(ctx, rdb)
|
|
|
+ allowed, err = tb.Allow(
|
|
|
+ ctx,
|
|
|
+ totalKey,
|
|
|
+ limiter.WithCapacity(int64(totalMaxCount)*duration),
|
|
|
+ limiter.WithRate(int64(totalMaxCount)),
|
|
|
+ limiter.WithRequested(duration),
|
|
|
+ )
|
|
|
|
|
|
- // 2. 检查成功请求数限制
|
|
|
- successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
|
|
|
- allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
|
|
|
if err != nil {
|
|
|
- fmt.Println("检查成功请求数限制失败:", err.Error())
|
|
|
+ fmt.Println("检查总请求数限制失败:", err.Error())
|
|
|
abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
|
|
|
return
|
|
|
}
|
|
|
+
|
|
|
if !allowed {
|
|
|
- abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
|
|
|
- return
|
|
|
+ abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
|
|
|
}
|
|
|
|
|
|
- // 3. 记录总请求(当totalMaxCount为0时会自动跳过)
|
|
|
- recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
|
|
|
-
|
|
|
// 4. 处理请求
|
|
|
c.Next()
|
|
|
|