| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- package limiter
import (
	"context"
	_ "embed"
	"fmt"
	"strings"
	"sync"

	"github.com/go-redis/redis/v8"
	"one-api/common"
)
// rateLimitScript is the Lua source of the rate-limit script, embedded into
// the binary at build time from lua/rate_limit.lua (a path relative to this
// package's directory). New preloads it into Redis via SCRIPT LOAD.
//
//go:embed lua/rate_limit.lua
var rateLimitScript string
// RedisLimiter runs the embedded Lua rate-limit script on a Redis server.
// Obtain it through New; the zero value has a nil client and cannot be used.
type RedisLimiter struct {
	// client is the Redis connection the script is evaluated on.
	client *redis.Client
	// limitScriptSHA is the script SHA returned by SCRIPT LOAD in New. It is
	// the empty string when that load failed (go-redis returns "" alongside
	// the error).
	limitScriptSHA string
}
- var (
- instance *RedisLimiter
- once sync.Once
- )
- func New(ctx context.Context, r *redis.Client) *RedisLimiter {
- once.Do(func() {
- // 预加载脚本
- limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
- if err != nil {
- common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
- }
- instance = &RedisLimiter{
- client: r,
- limitScriptSHA: limitSHA,
- }
- })
- return instance
- }
- func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
- // 默认配置
- config := &Config{
- Capacity: 10,
- Rate: 1,
- Requested: 1,
- }
- // 应用选项模式
- for _, opt := range opts {
- opt(config)
- }
- // 执行限流
- result, err := rl.client.EvalSha(
- ctx,
- rl.limitScriptSHA,
- []string{key},
- config.Requested,
- config.Rate,
- config.Capacity,
- ).Int()
- if err != nil {
- return false, fmt.Errorf("rate limit failed: %w", err)
- }
- return result == 1, nil
- }
// Config carries the per-call parameters that RedisLimiter.Allow forwards to
// the Lua rate-limit script. Callers normally adjust it through the Option
// helpers below instead of building it directly.
//
// NOTE(review): the exact meaning of each field is defined by
// lua/rate_limit.lua, which is not shown here. The descriptions below are
// inferred from the field names; confirm them against the script.
type Config struct {
	Capacity  int64 // presumably the maximum bucket size (burst allowance)
	Rate      int64 // presumably the refill rate per time unit
	Requested int64 // presumably the amount this call consumes
}

// Option is a functional option that mutates a Config in place.
type Option func(*Config)

// WithCapacity returns an Option that sets Config.Capacity to c.
func WithCapacity(c int64) Option {
	return func(target *Config) {
		target.Capacity = c
	}
}

// WithRate returns an Option that sets Config.Rate to r.
func WithRate(r int64) Option {
	return func(target *Config) {
		target.Rate = r
	}
}

// WithRequested returns an Option that sets Config.Requested to n.
func WithRequested(n int64) Option {
	return func(target *Config) {
		target.Requested = n
	}
}
|