Quellcode durchsuchen

feat(middleware): redis atomic incr, show waiting time

creamlike1024 vor 6 Monaten
Ursprung
Commit
6ea19b0ae2
1 geänderte Dateien mit 18 neuen und 8 gelöschten Zeilen
  1. 18 8
      middleware/email-verification-rate-limit.go

+ 18 - 8
middleware/email-verification-rate-limit.go

@@ -21,24 +21,34 @@ func redisEmailVerificationRateLimiter(c *gin.Context) {
 	rdb := common.RDB
 	key := "emailVerification:" + EmailVerificationRateLimitMark + ":" + c.ClientIP()
 
-	listLength, err := rdb.LLen(ctx, key).Result()
+	count, err := rdb.Incr(ctx, key).Result()
 	if err != nil {
-		fmt.Println("Redis限流检查失败:", err.Error())
-		c.Status(http.StatusInternalServerError)
-		c.Abort()
+		// fallback
+		memoryEmailVerificationRateLimiter(c)
 		return
 	}
 
-	if listLength < EmailVerificationMaxRequests {
-		rdb.LPush(ctx, key, time.Now().Format(timeFormat))
-		rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second)
+	// 第一次设置键时设置过期时间
+	if count == 1 {
+		_ = rdb.Expire(ctx, key, time.Duration(EmailVerificationDuration)*time.Second).Err()
+	}
+
+	// 检查是否超出限制
+	if count <= int64(EmailVerificationMaxRequests) {
 		c.Next()
 		return
 	}
 
+	// 获取剩余等待时间
+	ttl, err := rdb.TTL(ctx, key).Result()
+	waitSeconds := int64(EmailVerificationDuration)
+	if err == nil && ttl > 0 {
+		waitSeconds = int64(ttl.Seconds())
+	}
+
 	c.JSON(http.StatusTooManyRequests, gin.H{
 		"success": false,
-		"message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", EmailVerificationDuration),
+		"message": fmt.Sprintf("发送过于频繁,请等待 %d 秒后再试", waitSeconds),
 	})
 	c.Abort()
 }