| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- package middleware
- import (
- "context"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/common/limiter"
- "one-api/constant"
- "one-api/setting"
- "strconv"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/go-redis/redis/v8"
- )
// Key prefixes ("marks") used to build the per-user model request rate-limit keys.
const (
	// ModelRequestRateLimitCountMark prefixes the key that counts every request
	// (successful or not); in this file it is used by the in-memory limiter.
	ModelRequestRateLimitCountMark = "MRRL"
	// ModelRequestRateLimitSuccessCountMark prefixes the keys that track
	// successful requests, in both the Redis and the in-memory limiters.
	ModelRequestRateLimitSuccessCountMark = "MRRLS"
)
- // 检查Redis中的请求限制
- func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
- // 如果maxCount为0,表示不限制
- if maxCount == 0 {
- return true, nil
- }
- // 获取当前计数
- length, err := rdb.LLen(ctx, key).Result()
- if err != nil {
- return false, err
- }
- // 如果未达到限制,允许请求
- if length < int64(maxCount) {
- return true, nil
- }
- // 检查时间窗口
- oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
- oldTime, err := time.Parse(timeFormat, oldTimeStr)
- if err != nil {
- return false, err
- }
- nowTimeStr := time.Now().Format(timeFormat)
- nowTime, err := time.Parse(timeFormat, nowTimeStr)
- if err != nil {
- return false, err
- }
- // 如果在时间窗口内已达到限制,拒绝请求
- subTime := nowTime.Sub(oldTime).Seconds()
- if int64(subTime) < duration {
- rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
- return false, nil
- }
- return true, nil
- }
- // 记录Redis请求
- func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
- // 如果maxCount为0,不记录请求
- if maxCount == 0 {
- return
- }
- now := time.Now().Format(timeFormat)
- rdb.LPush(ctx, key, now)
- rdb.LTrim(ctx, key, 0, int64(maxCount-1))
- rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
- }
- // Redis限流处理器
- func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
- return func(c *gin.Context) {
- userId := strconv.Itoa(c.GetInt("id"))
- ctx := context.Background()
- rdb := common.RDB
- // 1. 检查成功请求数限制
- successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
- allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
- if err != nil {
- fmt.Println("检查成功请求数限制失败:", err.Error())
- abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
- return
- }
- if !allowed {
- abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
- return
- }
- //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
- if totalMaxCount > 0 {
- totalKey := fmt.Sprintf("rateLimit:%s", userId)
- // 初始化
- tb := limiter.New(ctx, rdb)
- allowed, err = tb.Allow(
- ctx,
- totalKey,
- limiter.WithCapacity(int64(totalMaxCount)*duration),
- limiter.WithRate(int64(totalMaxCount)),
- limiter.WithRequested(duration),
- )
- if err != nil {
- fmt.Println("检查总请求数限制失败:", err.Error())
- abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
- return
- }
- if !allowed {
- abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
- }
- }
- // 4. 处理请求
- c.Next()
- // 5. 如果请求成功,记录成功请求
- if c.Writer.Status() < 400 {
- recordRedisRequest(ctx, rdb, successKey, successMaxCount)
- }
- }
- }
- // 内存限流处理器
- func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
- inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
- return func(c *gin.Context) {
- userId := strconv.Itoa(c.GetInt("id"))
- totalKey := ModelRequestRateLimitCountMark + userId
- successKey := ModelRequestRateLimitSuccessCountMark + userId
- // 1. 检查总请求数限制(当totalMaxCount为0时跳过)
- if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
- c.Status(http.StatusTooManyRequests)
- c.Abort()
- return
- }
- // 2. 检查成功请求数限制
- // 使用一个临时key来检查限制,这样可以避免实际记录
- checkKey := successKey + "_check"
- if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
- c.Status(http.StatusTooManyRequests)
- c.Abort()
- return
- }
- // 3. 处理请求
- c.Next()
- // 4. 如果请求成功,记录到实际的成功请求计数中
- if c.Writer.Status() < 400 {
- inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
- }
- }
- }
- // ModelRequestRateLimit 模型请求限流中间件
- func ModelRequestRateLimit() func(c *gin.Context) {
- return func(c *gin.Context) {
- // 在每个请求时检查是否启用限流
- if !setting.ModelRequestRateLimitEnabled {
- c.Next()
- return
- }
- // 计算限流参数
- duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
- totalMaxCount := setting.ModelRequestRateLimitCount
- successMaxCount := setting.ModelRequestRateLimitSuccessCount
- // 获取分组
- group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
- if group == "" {
- group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
- }
- //获取分组的限流配置
- groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
- if found {
- totalMaxCount = groupTotalCount
- successMaxCount = groupSuccessCount
- }
- // 根据存储类型选择并执行限流处理器
- if common.RedisEnabled {
- redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
- } else {
- memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
- }
- }
- }
|