model-rate-limit.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "one-api/common"
  7. "one-api/common/limiter"
  8. "one-api/constant"
  9. "one-api/setting"
  10. "strconv"
  11. "time"
  12. "github.com/gin-gonic/gin"
  13. "github.com/go-redis/redis/v8"
  14. )
  15. const (
  16. ModelRequestRateLimitCountMark = "MRRL"
  17. ModelRequestRateLimitSuccessCountMark = "MRRLS"
  18. )
  19. // 检查Redis中的请求限制
  20. func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
  21. // 如果maxCount为0,表示不限制
  22. if maxCount == 0 {
  23. return true, nil
  24. }
  25. // 获取当前计数
  26. length, err := rdb.LLen(ctx, key).Result()
  27. if err != nil {
  28. return false, err
  29. }
  30. // 如果未达到限制,允许请求
  31. if length < int64(maxCount) {
  32. return true, nil
  33. }
  34. // 检查时间窗口
  35. oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
  36. oldTime, err := time.Parse(timeFormat, oldTimeStr)
  37. if err != nil {
  38. return false, err
  39. }
  40. nowTimeStr := time.Now().Format(timeFormat)
  41. nowTime, err := time.Parse(timeFormat, nowTimeStr)
  42. if err != nil {
  43. return false, err
  44. }
  45. // 如果在时间窗口内已达到限制,拒绝请求
  46. subTime := nowTime.Sub(oldTime).Seconds()
  47. if int64(subTime) < duration {
  48. rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
  49. return false, nil
  50. }
  51. return true, nil
  52. }
  53. // 记录Redis请求
  54. func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
  55. // 如果maxCount为0,不记录请求
  56. if maxCount == 0 {
  57. return
  58. }
  59. now := time.Now().Format(timeFormat)
  60. rdb.LPush(ctx, key, now)
  61. rdb.LTrim(ctx, key, 0, int64(maxCount-1))
  62. rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute)
  63. }
  64. // Redis限流处理器
  65. func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
  66. return func(c *gin.Context) {
  67. userId := strconv.Itoa(c.GetInt("id"))
  68. ctx := context.Background()
  69. rdb := common.RDB
  70. // 1. 检查成功请求数限制
  71. successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
  72. allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
  73. if err != nil {
  74. fmt.Println("检查成功请求数限制失败:", err.Error())
  75. abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
  76. return
  77. }
  78. if !allowed {
  79. abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
  80. return
  81. }
  82. //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器
  83. if totalMaxCount > 0 {
  84. totalKey := fmt.Sprintf("rateLimit:%s", userId)
  85. // 初始化
  86. tb := limiter.New(ctx, rdb)
  87. allowed, err = tb.Allow(
  88. ctx,
  89. totalKey,
  90. limiter.WithCapacity(int64(totalMaxCount)*duration),
  91. limiter.WithRate(int64(totalMaxCount)),
  92. limiter.WithRequested(duration),
  93. )
  94. if err != nil {
  95. fmt.Println("检查总请求数限制失败:", err.Error())
  96. abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
  97. return
  98. }
  99. if !allowed {
  100. abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
  101. }
  102. }
  103. // 4. 处理请求
  104. c.Next()
  105. // 5. 如果请求成功,记录成功请求
  106. if c.Writer.Status() < 400 {
  107. recordRedisRequest(ctx, rdb, successKey, successMaxCount)
  108. }
  109. }
  110. }
  111. // 内存限流处理器
  112. func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
  113. inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute)
  114. return func(c *gin.Context) {
  115. userId := strconv.Itoa(c.GetInt("id"))
  116. totalKey := ModelRequestRateLimitCountMark + userId
  117. successKey := ModelRequestRateLimitSuccessCountMark + userId
  118. // 1. 检查总请求数限制(当totalMaxCount为0时跳过)
  119. if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
  120. c.Status(http.StatusTooManyRequests)
  121. c.Abort()
  122. return
  123. }
  124. // 2. 检查成功请求数限制
  125. // 使用一个临时key来检查限制,这样可以避免实际记录
  126. checkKey := successKey + "_check"
  127. if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
  128. c.Status(http.StatusTooManyRequests)
  129. c.Abort()
  130. return
  131. }
  132. // 3. 处理请求
  133. c.Next()
  134. // 4. 如果请求成功,记录到实际的成功请求计数中
  135. if c.Writer.Status() < 400 {
  136. inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
  137. }
  138. }
  139. }
  140. // ModelRequestRateLimit 模型请求限流中间件
  141. func ModelRequestRateLimit() func(c *gin.Context) {
  142. return func(c *gin.Context) {
  143. // 在每个请求时检查是否启用限流
  144. if !setting.ModelRequestRateLimitEnabled {
  145. c.Next()
  146. return
  147. }
  148. // 计算限流参数
  149. duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
  150. totalMaxCount := setting.ModelRequestRateLimitCount
  151. successMaxCount := setting.ModelRequestRateLimitSuccessCount
  152. // 获取分组
  153. group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup)
  154. if group == "" {
  155. group = common.GetContextKeyString(c, constant.ContextKeyUserGroup)
  156. }
  157. //获取分组的限流配置
  158. groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
  159. if found {
  160. totalMaxCount = groupTotalCount
  161. successMaxCount = groupSuccessCount
  162. }
  163. // 根据存储类型选择并执行限流处理器
  164. if common.RedisEnabled {
  165. redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
  166. } else {
  167. memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)
  168. }
  169. }
  170. }