1
0

rate-limit.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package middleware
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "message-pusher/common"
  7. "net/http"
  8. "time"
  9. )
  10. var timeFormat = "2006-01-02T15:04:05.000Z"
  11. var inMemoryRateLimiter common.InMemoryRateLimiter
  12. func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
  13. ctx := context.Background()
  14. rdb := common.RDB
  15. key := "rateLimit:" + mark + c.ClientIP()
  16. listLength, err := rdb.LLen(ctx, key).Result()
  17. if err != nil {
  18. fmt.Println(err.Error())
  19. c.Status(http.StatusInternalServerError)
  20. c.Abort()
  21. return
  22. }
  23. if listLength < int64(maxRequestNum) {
  24. rdb.LPush(ctx, key, time.Now().Format(timeFormat))
  25. rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
  26. } else {
  27. oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
  28. oldTime, err := time.Parse(timeFormat, oldTimeStr)
  29. if err != nil {
  30. fmt.Println(err)
  31. c.Status(http.StatusInternalServerError)
  32. c.Abort()
  33. return
  34. }
  35. nowTimeStr := time.Now().Format(timeFormat)
  36. nowTime, err := time.Parse(timeFormat, nowTimeStr)
  37. if err != nil {
  38. fmt.Println(err)
  39. c.Status(http.StatusInternalServerError)
  40. c.Abort()
  41. return
  42. }
  43. // time.Since will return negative number!
  44. // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
  45. if int64(nowTime.Sub(oldTime).Seconds()) < duration {
  46. rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
  47. c.Status(http.StatusTooManyRequests)
  48. c.Abort()
  49. return
  50. } else {
  51. rdb.LPush(ctx, key, time.Now().Format(timeFormat))
  52. rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
  53. rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
  54. }
  55. }
  56. }
  57. func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
  58. key := mark + c.ClientIP()
  59. if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
  60. c.Status(http.StatusTooManyRequests)
  61. c.Abort()
  62. return
  63. }
  64. }
  65. func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
  66. if common.RedisEnabled {
  67. return func(c *gin.Context) {
  68. redisRateLimiter(c, maxRequestNum, duration, mark)
  69. }
  70. } else {
  71. // It's safe to call multi times.
  72. inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
  73. return func(c *gin.Context) {
  74. memoryRateLimiter(c, maxRequestNum, duration, mark)
  75. }
  76. }
  77. }
  78. func GlobalWebRateLimit() func(c *gin.Context) {
  79. return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
  80. }
  81. func GlobalAPIRateLimit() func(c *gin.Context) {
  82. return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
  83. }
  84. func CriticalRateLimit() func(c *gin.Context) {
  85. return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
  86. }
  87. func DownloadRateLimit() func(c *gin.Context) {
  88. return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
  89. }
  90. func UploadRateLimit() func(c *gin.Context) {
  91. return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
  92. }