redis.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package reqlimit
  2. import (
  3. "context"
  4. "errors"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/labring/aiproxy/core/common"
  9. "github.com/redis/go-redis/v9"
  10. )
  11. type redisRateRecord struct {
  12. prefix string
  13. }
  14. func newRedisGroupModelRecord() *redisRateRecord {
  15. return &redisRateRecord{
  16. prefix: "group-model-record",
  17. }
  18. }
  19. func newRedisGroupModelTokennameRecord() *redisRateRecord {
  20. return &redisRateRecord{
  21. prefix: "group-model-tokenname-record",
  22. }
  23. }
  24. func newRedisChannelModelRecord() *redisRateRecord {
  25. return &redisRateRecord{
  26. prefix: "channel-model-record",
  27. }
  28. }
  29. func newRedisGroupModelTokensRecord() *redisRateRecord {
  30. return &redisRateRecord{
  31. prefix: "group-model-tokens-record",
  32. }
  33. }
  34. func newRedisGroupModelTokennameTokensRecord() *redisRateRecord {
  35. return &redisRateRecord{
  36. prefix: "group-model-tokenname-tokens-record",
  37. }
  38. }
  39. func newRedisChannelModelTokensRecord() *redisRateRecord {
  40. return &redisRateRecord{
  41. prefix: "channel-model-tokens-record",
  42. }
  43. }
  44. const pushRequestLuaScript = `
  45. local key = KEYS[1]
  46. local window_seconds = tonumber(ARGV[1])
  47. local current_time = tonumber(ARGV[2])
  48. local max_requests = tonumber(ARGV[3])
  49. local n = tonumber(ARGV[4])
  50. local cutoff_slice = current_time - window_seconds
  51. local function parse_count(value)
  52. if not value then return 0, 0 end
  53. local r, e = value:match("^(%d+):(%d+)$")
  54. return tonumber(r) or 0, tonumber(e) or 0
  55. end
  56. local count = 0
  57. local over_count = 0
  58. local all_fields = redis.call('HGETALL', key)
  59. for i = 1, #all_fields, 2 do
  60. local field_slice = tonumber(all_fields[i])
  61. if field_slice < cutoff_slice then
  62. redis.call('HDEL', key, all_fields[i])
  63. else
  64. local c, oc = parse_count(all_fields[i+1])
  65. count = count + c
  66. over_count = over_count + oc
  67. end
  68. end
  69. local current_value = redis.call('HGET', key, tostring(current_time))
  70. local current_c, current_oc = parse_count(current_value)
  71. if max_requests == 0 or count <= max_requests then
  72. current_c = current_c + n
  73. count = count + n
  74. else
  75. current_oc = current_oc + n
  76. over_count = over_count + n
  77. end
  78. redis.call('HSET', key, current_time, current_c .. ":" .. current_oc)
  79. redis.call('EXPIRE', key, window_seconds)
  80. local current_second_count = current_c + current_oc
  81. return string.format("%d:%d:%d", count, over_count, current_second_count)
  82. `
// getRequestCountLuaScript sums request counters across every hash whose name
// matches the glob pattern given as KEYS[1].
//
//	ARGV[1]  window length in seconds
//	ARGV[2]  current unix time in seconds
//
// Fields older than ARGV[2]-ARGV[1] are deleted while scanning, so this read
// also prunes data. Each remaining field value is parsed as "<normal>:<over>",
// and both parts are added to the total. Values that do not match count as 0.
// The reply is "<total>:<current-second total>".
//
// NOTE(review): the pattern is resolved with the KEYS command. KEYS walks the
// whole keyspace, and the script blocks the server while it runs. Under Redis
// Cluster it presumably sees only keys on the node that executes the script.
// Consider a SCAN-based lookup if the keyspace is large. TODO confirm.
const getRequestCountLuaScript = `
local pattern = KEYS[1]
local window_seconds = tonumber(ARGV[1])
local current_time = tonumber(ARGV[2])
local cutoff_slice = current_time - window_seconds
local function parse_count(value)
if not value then return 0, 0 end
local r, e = value:match("^(%d+):(%d+)$")
return tonumber(r) or 0, tonumber(e) or 0
end
local total = 0
local current_second_count = 0
local keys = redis.call('KEYS', pattern)
for _, key in ipairs(keys) do
local count = 0
local over = 0
local all_fields = redis.call('HGETALL', key)
for i=1, #all_fields, 2 do
local field_slice = tonumber(all_fields[i])
if field_slice < cutoff_slice then
redis.call('HDEL', key, all_fields[i])
else
local c, oc = parse_count(all_fields[i+1])
count = count + c
over = over + oc
if field_slice == current_time then
current_second_count = current_second_count + c + oc
end
end
end
total = total + count + over
end
return string.format("%d:%d", total, current_second_count)
`
// Package-level script handles for the two Lua sources in this file, shared
// by every redisRateRecord. PushRequest runs pushRequestScript; GetRequest
// runs getRequestCountScript.
var (
	pushRequestScript     = redis.NewScript(pushRequestLuaScript)
	getRequestCountScript = redis.NewScript(getRequestCountLuaScript)
)
  121. func (r *redisRateRecord) buildKey(keys ...string) string {
  122. return common.RedisKey(r.prefix + ":" + strings.Join(keys, ":"))
  123. }
  124. func (r *redisRateRecord) GetRequest(
  125. ctx context.Context,
  126. duration time.Duration,
  127. keys ...string,
  128. ) (totalCount, secondCount int64, err error) {
  129. if !common.RedisEnabled {
  130. return 0, 0, nil
  131. }
  132. pattern := r.buildKey(keys...)
  133. result, err := getRequestCountScript.Run(
  134. ctx,
  135. common.RDB,
  136. []string{pattern},
  137. duration.Seconds(),
  138. time.Now().Unix(),
  139. ).Text()
  140. if err != nil {
  141. return 0, 0, err
  142. }
  143. parts := strings.Split(result, ":")
  144. if len(parts) != 2 {
  145. return 0, 0, errors.New("invalid result format")
  146. }
  147. totalCountInt, err := strconv.ParseInt(parts[0], 10, 64)
  148. if err != nil {
  149. return 0, 0, err
  150. }
  151. secondCountInt, err := strconv.ParseInt(parts[1], 10, 64)
  152. if err != nil {
  153. return 0, 0, err
  154. }
  155. return totalCountInt, secondCountInt, nil
  156. }
  157. func (r *redisRateRecord) PushRequest(
  158. ctx context.Context,
  159. overed int64,
  160. duration time.Duration,
  161. n int64,
  162. keys ...string,
  163. ) (normalCount, overCount, secondCount int64, err error) {
  164. key := r.buildKey(keys...)
  165. result, err := pushRequestScript.Run(
  166. ctx,
  167. common.RDB,
  168. []string{key},
  169. duration.Seconds(),
  170. time.Now().Unix(),
  171. overed,
  172. n,
  173. ).Text()
  174. if err != nil {
  175. return 0, 0, 0, err
  176. }
  177. parts := strings.Split(result, ":")
  178. if len(parts) != 3 {
  179. return 0, 0, 0, errors.New("invalid result")
  180. }
  181. countInt, err := strconv.ParseInt(parts[0], 10, 64)
  182. if err != nil {
  183. return 0, 0, 0, err
  184. }
  185. overLimitCountInt, err := strconv.ParseInt(parts[1], 10, 64)
  186. if err != nil {
  187. return 0, 0, 0, err
  188. }
  189. secondCountInt, err := strconv.ParseInt(parts[2], 10, 64)
  190. if err != nil {
  191. return 0, 0, 0, err
  192. }
  193. return countInt, overLimitCountInt, secondCountInt, nil
  194. }