limiter.go 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package limiter
import (
	"context"
	_ "embed"
	"fmt"
	"strings"
	"sync"

	"github.com/go-redis/redis/v8"

	"one-api/common"
)
  10. //go:embed lua/rate_limit.lua
  11. var rateLimitScript string
  12. type RedisLimiter struct {
  13. client *redis.Client
  14. limitScriptSHA string
  15. }
  16. var (
  17. instance *RedisLimiter
  18. once sync.Once
  19. )
  20. func New(ctx context.Context, r *redis.Client) *RedisLimiter {
  21. once.Do(func() {
  22. // 预加载脚本
  23. limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
  24. if err != nil {
  25. common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
  26. }
  27. instance = &RedisLimiter{
  28. client: r,
  29. limitScriptSHA: limitSHA,
  30. }
  31. })
  32. return instance
  33. }
  34. func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
  35. // 默认配置
  36. config := &Config{
  37. Capacity: 10,
  38. Rate: 1,
  39. Requested: 1,
  40. }
  41. // 应用选项模式
  42. for _, opt := range opts {
  43. opt(config)
  44. }
  45. // 执行限流
  46. result, err := rl.client.EvalSha(
  47. ctx,
  48. rl.limitScriptSHA,
  49. []string{key},
  50. config.Requested,
  51. config.Rate,
  52. config.Capacity,
  53. ).Int()
  54. if err != nil {
  55. return false, fmt.Errorf("rate limit failed: %w", err)
  56. }
  57. return result == 1, nil
  58. }
  59. // Config 配置选项模式
  60. type Config struct {
  61. Capacity int64
  62. Rate int64
  63. Requested int64
  64. }
  65. type Option func(*Config)
  66. func WithCapacity(c int64) Option {
  67. return func(cfg *Config) { cfg.Capacity = c }
  68. }
  69. func WithRate(r int64) Option {
  70. return func(cfg *Config) { cfg.Rate = r }
  71. }
  72. func WithRequested(n int64) Option {
  73. return func(cfg *Config) { cfg.Requested = n }
  74. }