|
|
@@ -5,14 +5,26 @@ import {
|
|
|
CHECK_AND_TRACK_SESSION,
|
|
|
TRACK_COST_5H_ROLLING_WINDOW,
|
|
|
GET_COST_5H_ROLLING_WINDOW,
|
|
|
+ TRACK_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ GET_COST_DAILY_ROLLING_WINDOW,
|
|
|
} from "@/lib/redis/lua-scripts";
|
|
|
import { sumUserCostToday } from "@/repository/statistics";
|
|
|
-import { getTimeRangeForPeriod, getTTLForPeriod, getSecondsUntilMidnight } from "./time-utils";
|
|
|
+import {
|
|
|
+ getTimeRangeForPeriod,
|
|
|
+ getTimeRangeForPeriodWithMode,
|
|
|
+ getTTLForPeriod,
|
|
|
+ getTTLForPeriodWithMode,
|
|
|
+ getSecondsUntilMidnight,
|
|
|
+ normalizeResetTime,
|
|
|
+ type DailyResetMode,
|
|
|
+} from "./time-utils";
|
|
|
|
|
|
interface CostLimit {
|
|
|
amount: number | null;
|
|
|
- period: "5h" | "weekly" | "monthly";
|
|
|
+ period: "5h" | "daily" | "weekly" | "monthly";
|
|
|
name: string;
|
|
|
+ resetTime?: string; // 自定义重置时间(仅 daily + fixed 模式使用,格式 "HH:mm")
|
|
|
+ resetMode?: DailyResetMode; // 日限额重置模式(仅 daily 使用)
|
|
|
}
|
|
|
|
|
|
export class RateLimitService {
|
|
|
@@ -21,6 +33,11 @@ export class RateLimitService {
|
|
|
return getRedisClient();
|
|
|
}
|
|
|
|
|
|
+ private static resolveDailyReset(resetTime?: string): { normalized: string; suffix: string } {
|
|
|
+ const normalized = normalizeResetTime(resetTime);
|
|
|
+ return { normalized, suffix: normalized.replace(":", "") };
|
|
|
+ }
|
|
|
+
|
|
|
/**
|
|
|
* 检查金额限制(Key 或 Provider)
|
|
|
* 优先使用 Redis,失败时降级到数据库查询(防止 Redis 清空后超支)
|
|
|
@@ -30,12 +47,24 @@ export class RateLimitService {
|
|
|
type: "key" | "provider",
|
|
|
limits: {
|
|
|
limit_5h_usd: number | null;
|
|
|
+ limit_daily_usd: number | null;
|
|
|
+ daily_reset_time?: string;
|
|
|
+ daily_reset_mode?: DailyResetMode;
|
|
|
limit_weekly_usd: number | null;
|
|
|
limit_monthly_usd: number | null;
|
|
|
}
|
|
|
): Promise<{ allowed: boolean; reason?: string }> {
|
|
|
+ const normalizedDailyReset = normalizeResetTime(limits.daily_reset_time);
|
|
|
+ const dailyResetMode = limits.daily_reset_mode ?? "fixed";
|
|
|
const costLimits: CostLimit[] = [
|
|
|
{ amount: limits.limit_5h_usd, period: "5h", name: "5小时" },
|
|
|
+ {
|
|
|
+ amount: limits.limit_daily_usd,
|
|
|
+ period: "daily",
|
|
|
+ name: "每日",
|
|
|
+ resetTime: normalizedDailyReset,
|
|
|
+ resetMode: dailyResetMode,
|
|
|
+ },
|
|
|
{ amount: limits.limit_weekly_usd, period: "weekly", name: "周" },
|
|
|
{ amount: limits.limit_monthly_usd, period: "monthly", name: "月" },
|
|
|
];
|
|
|
@@ -82,14 +111,48 @@ export class RateLimitService {
|
|
|
);
|
|
|
return await this.checkCostLimitsFromDatabase(id, type, costLimits);
|
|
|
}
|
|
|
+ } else if (limit.period === "daily" && limit.resetMode === "rolling") {
|
|
|
+ // daily 滚动窗口:使用 ZSET + Lua 脚本
|
|
|
+ try {
|
|
|
+ const key = `${type}:${id}:cost_daily_rolling`;
|
|
|
+ const window24h = 24 * 60 * 60 * 1000;
|
|
|
+ const result = (await this.redis.eval(
|
|
|
+ GET_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ 1,
|
|
|
+ key,
|
|
|
+ now.toString(),
|
|
|
+ window24h.toString()
|
|
|
+ )) as string;
|
|
|
+
|
|
|
+ current = parseFloat(result || "0");
|
|
|
+
|
|
|
+ // Cache Miss 检测
|
|
|
+ if (current === 0) {
|
|
|
+ const exists = await this.redis.exists(key);
|
|
|
+ if (!exists) {
|
|
|
+ logger.info(
|
|
|
+ `[RateLimit] Cache miss for ${type}:${id}:cost_daily_rolling, querying database`
|
|
|
+ );
|
|
|
+ return await this.checkCostLimitsFromDatabase(id, type, costLimits);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (error) {
|
|
|
+ logger.error(
|
|
|
+ "[RateLimit] Daily rolling window query failed, fallback to database:",
|
|
|
+ error
|
|
|
+ );
|
|
|
+ return await this.checkCostLimitsFromDatabase(id, type, costLimits);
|
|
|
+ }
|
|
|
} else {
|
|
|
- // 周/月使用普通 GET
|
|
|
- const value = await this.redis.get(`${type}:${id}:cost_${limit.period}`);
|
|
|
+ // daily fixed/周/月使用普通 GET
|
|
|
+ const { suffix } = this.resolveDailyReset(limit.resetTime);
|
|
|
+ const periodKey = limit.period === "daily" ? `${limit.period}_${suffix}` : limit.period;
|
|
|
+ const value = await this.redis.get(`${type}:${id}:cost_${periodKey}`);
|
|
|
|
|
|
// Cache Miss 检测
|
|
|
if (value === null && limit.amount > 0) {
|
|
|
logger.info(
|
|
|
- `[RateLimit] Cache miss for ${type}:${id}:cost_${limit.period}, querying database`
|
|
|
+ `[RateLimit] Cache miss for ${type}:${id}:cost_${periodKey}, querying database`
|
|
|
);
|
|
|
return await this.checkCostLimitsFromDatabase(id, type, costLimits);
|
|
|
}
|
|
|
@@ -132,8 +195,12 @@ export class RateLimitService {
|
|
|
for (const limit of costLimits) {
|
|
|
if (!limit.amount || limit.amount <= 0) continue;
|
|
|
|
|
|
- // 计算时间范围(使用新的时间工具函数)
|
|
|
- const { startTime, endTime } = getTimeRangeForPeriod(limit.period);
|
|
|
+ // 计算时间范围(使用支持模式的时间工具函数)
|
|
|
+ const { startTime, endTime } = getTimeRangeForPeriodWithMode(
|
|
|
+ limit.period,
|
|
|
+ limit.resetTime,
|
|
|
+ limit.resetMode
|
|
|
+ );
|
|
|
|
|
|
// 查询数据库
|
|
|
const current =
|
|
|
@@ -162,17 +229,34 @@ export class RateLimitService {
|
|
|
|
|
|
logger.info(`[RateLimit] Cache warmed for ${key}, value=${current} (rolling window)`);
|
|
|
}
|
|
|
+ } else if (limit.period === "daily" && limit.resetMode === "rolling") {
|
|
|
+ // daily 滚动窗口:使用 ZSET + Lua 脚本
|
|
|
+ if (current > 0) {
|
|
|
+ const now = Date.now();
|
|
|
+ const window24h = 24 * 60 * 60 * 1000;
|
|
|
+ const key = `${type}:${id}:cost_daily_rolling`;
|
|
|
+
|
|
|
+ await this.redis.eval(
|
|
|
+ TRACK_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ 1,
|
|
|
+ key,
|
|
|
+ current.toString(),
|
|
|
+ now.toString(),
|
|
|
+ window24h.toString()
|
|
|
+ );
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ `[RateLimit] Cache warmed for ${key}, value=${current} (daily rolling window)`
|
|
|
+ );
|
|
|
+ }
|
|
|
} else {
|
|
|
- // 周/月固定窗口:使用 STRING + 动态 TTL
|
|
|
- const ttl = getTTLForPeriod(limit.period);
|
|
|
- await this.redis.set(
|
|
|
- `${type}:${id}:cost_${limit.period}`,
|
|
|
- current.toString(),
|
|
|
- "EX",
|
|
|
- ttl
|
|
|
- );
|
|
|
+ // daily fixed/周/月固定窗口:使用 STRING + 动态 TTL
|
|
|
+ const { normalized, suffix } = this.resolveDailyReset(limit.resetTime);
|
|
|
+ const ttl = getTTLForPeriodWithMode(limit.period, normalized, limit.resetMode);
|
|
|
+ const periodKey = limit.period === "daily" ? `${limit.period}_${suffix}` : limit.period;
|
|
|
+ await this.redis.set(`${type}:${id}:cost_${periodKey}`, current.toString(), "EX", ttl);
|
|
|
logger.info(
|
|
|
- `[RateLimit] Cache warmed for ${type}:${id}:cost_${limit.period}, value=${current}, ttl=${ttl}s`
|
|
|
+ `[RateLimit] Cache warmed for ${type}:${id}:cost_${periodKey}, value=${current}, ttl=${ttl}s`
|
|
|
);
|
|
|
}
|
|
|
} catch (error) {
|
|
|
@@ -289,21 +373,38 @@ export class RateLimitService {
|
|
|
|
|
|
/**
|
|
|
* 累加消费(请求结束后调用)
|
|
|
- * 5h 使用滚动窗口(ZSET),周/月使用固定窗口(STRING)
|
|
|
+ * 5h 使用滚动窗口(ZSET),daily 根据模式选择滚动/固定窗口,周/月使用固定窗口(STRING)
|
|
|
*/
|
|
|
static async trackCost(
|
|
|
keyId: number,
|
|
|
providerId: number,
|
|
|
sessionId: string,
|
|
|
- cost: number
|
|
|
+ cost: number,
|
|
|
+ options?: {
|
|
|
+ keyResetTime?: string;
|
|
|
+ keyResetMode?: DailyResetMode;
|
|
|
+ providerResetTime?: string;
|
|
|
+ providerResetMode?: DailyResetMode;
|
|
|
+ }
|
|
|
): Promise<void> {
|
|
|
if (!this.redis || cost <= 0) return;
|
|
|
|
|
|
try {
|
|
|
+ const keyDailyReset = this.resolveDailyReset(options?.keyResetTime);
|
|
|
+ const providerDailyReset = this.resolveDailyReset(options?.providerResetTime);
|
|
|
+ const keyDailyMode = options?.keyResetMode ?? "fixed";
|
|
|
+ const providerDailyMode = options?.providerResetMode ?? "fixed";
|
|
|
const now = Date.now();
|
|
|
const window5h = 5 * 60 * 60 * 1000; // 5 hours in ms
|
|
|
-
|
|
|
- // 计算动态 TTL(周/月)
|
|
|
+ const window24h = 24 * 60 * 60 * 1000; // 24 hours in ms
|
|
|
+
|
|
|
+ // 计算动态 TTL(daily/周/月)
|
|
|
+ const ttlDailyKey = getTTLForPeriodWithMode("daily", keyDailyReset.normalized, keyDailyMode);
|
|
|
+ const ttlDailyProvider =
|
|
|
+ keyDailyReset.normalized === providerDailyReset.normalized &&
|
|
|
+ keyDailyMode === providerDailyMode
|
|
|
+ ? ttlDailyKey
|
|
|
+ : getTTLForPeriodWithMode("daily", providerDailyReset.normalized, providerDailyMode);
|
|
|
const ttlWeekly = getTTLForPeriod("weekly");
|
|
|
const ttlMonthly = getTTLForPeriod("monthly");
|
|
|
|
|
|
@@ -328,17 +429,52 @@ export class RateLimitService {
|
|
|
window5h.toString()
|
|
|
);
|
|
|
|
|
|
- // 2. 周/月固定窗口:使用 STRING + 动态 TTL
|
|
|
+ // 2. daily 滚动窗口:使用 Lua 脚本(ZSET)
|
|
|
+ if (keyDailyMode === "rolling") {
|
|
|
+ await this.redis.eval(
|
|
|
+ TRACK_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ 1,
|
|
|
+ `key:${keyId}:cost_daily_rolling`,
|
|
|
+ cost.toString(),
|
|
|
+ now.toString(),
|
|
|
+ window24h.toString()
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ if (providerDailyMode === "rolling") {
|
|
|
+ await this.redis.eval(
|
|
|
+ TRACK_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ 1,
|
|
|
+ `provider:${providerId}:cost_daily_rolling`,
|
|
|
+ cost.toString(),
|
|
|
+ now.toString(),
|
|
|
+ window24h.toString()
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ // 3. daily fixed/周/月固定窗口:使用 STRING + 动态 TTL
|
|
|
const pipeline = this.redis.pipeline();
|
|
|
|
|
|
- // Key 的周/月消费
|
|
|
+ // Key 的 daily fixed/周/月消费
|
|
|
+ if (keyDailyMode === "fixed") {
|
|
|
+ const keyDailyKey = `key:${keyId}:cost_daily_${keyDailyReset.suffix}`;
|
|
|
+ pipeline.incrbyfloat(keyDailyKey, cost);
|
|
|
+ pipeline.expire(keyDailyKey, ttlDailyKey);
|
|
|
+ }
|
|
|
+
|
|
|
pipeline.incrbyfloat(`key:${keyId}:cost_weekly`, cost);
|
|
|
pipeline.expire(`key:${keyId}:cost_weekly`, ttlWeekly);
|
|
|
|
|
|
pipeline.incrbyfloat(`key:${keyId}:cost_monthly`, cost);
|
|
|
pipeline.expire(`key:${keyId}:cost_monthly`, ttlMonthly);
|
|
|
|
|
|
- // Provider 的周/月消费
|
|
|
+ // Provider 的 daily fixed/周/月消费
|
|
|
+ if (providerDailyMode === "fixed") {
|
|
|
+ const providerDailyKey = `provider:${providerId}:cost_daily_${providerDailyReset.suffix}`;
|
|
|
+ pipeline.incrbyfloat(providerDailyKey, cost);
|
|
|
+ pipeline.expire(providerDailyKey, ttlDailyProvider);
|
|
|
+ }
|
|
|
+
|
|
|
pipeline.incrbyfloat(`provider:${providerId}:cost_weekly`, cost);
|
|
|
pipeline.expire(`provider:${providerId}:cost_weekly`, ttlWeekly);
|
|
|
|
|
|
@@ -361,9 +497,12 @@ export class RateLimitService {
|
|
|
static async getCurrentCost(
|
|
|
id: number,
|
|
|
type: "key" | "provider",
|
|
|
- period: "5h" | "weekly" | "monthly"
|
|
|
+ period: "5h" | "daily" | "weekly" | "monthly",
|
|
|
+ resetTime = "00:00",
|
|
|
+ resetMode: DailyResetMode = "fixed"
|
|
|
): Promise<number> {
|
|
|
try {
|
|
|
+ const dailyResetInfo = this.resolveDailyReset(resetTime);
|
|
|
// Fast Path: Redis 查询
|
|
|
if (this.redis && this.redis.status === "ready") {
|
|
|
let current = 0;
|
|
|
@@ -397,9 +536,40 @@ export class RateLimitService {
|
|
|
// Key 存在但值为 0,说明真的是 0
|
|
|
return 0;
|
|
|
}
|
|
|
+ } else if (period === "daily" && resetMode === "rolling") {
|
|
|
+ // daily 滚动窗口:使用 ZSET + Lua 脚本
|
|
|
+ const now = Date.now();
|
|
|
+ const window24h = 24 * 60 * 60 * 1000;
|
|
|
+ const key = `${type}:${id}:cost_daily_rolling`;
|
|
|
+
|
|
|
+ const result = (await this.redis.eval(
|
|
|
+ GET_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ 1,
|
|
|
+ key,
|
|
|
+ now.toString(),
|
|
|
+ window24h.toString()
|
|
|
+ )) as string;
|
|
|
+
|
|
|
+ current = parseFloat(result || "0");
|
|
|
+
|
|
|
+ // Cache Hit
|
|
|
+ if (current > 0) {
|
|
|
+ return current;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Cache Miss 检测
|
|
|
+ const exists = await this.redis.exists(key);
|
|
|
+ if (!exists) {
|
|
|
+ logger.info(
|
|
|
+ `[RateLimit] Cache miss for ${type}:${id}:cost_daily_rolling, querying database`
|
|
|
+ );
|
|
|
+ } else {
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
} else {
|
|
|
- // 周/月使用普通 GET
|
|
|
- const value = await this.redis.get(`${type}:${id}:cost_${period}`);
|
|
|
+ // daily fixed/周/月使用普通 GET
|
|
|
+ const redisKey = period === "daily" ? `${period}_${dailyResetInfo.suffix}` : period;
|
|
|
+ const value = await this.redis.get(`${type}:${id}:cost_${redisKey}`);
|
|
|
|
|
|
// Cache Hit
|
|
|
if (value !== null) {
|
|
|
@@ -407,7 +577,9 @@ export class RateLimitService {
|
|
|
}
|
|
|
|
|
|
// Cache Miss: 从数据库恢复
|
|
|
- logger.info(`[RateLimit] Cache miss for ${type}:${id}:cost_${period}, querying database`);
|
|
|
+ logger.info(
|
|
|
+ `[RateLimit] Cache miss for ${type}:${id}:cost_${redisKey}, querying database`
|
|
|
+ );
|
|
|
}
|
|
|
} else {
|
|
|
logger.warn(`[RateLimit] Redis unavailable, querying database for ${type} cost`);
|
|
|
@@ -418,7 +590,11 @@ export class RateLimitService {
|
|
|
"@/repository/statistics"
|
|
|
);
|
|
|
|
|
|
- const { startTime, endTime } = getTimeRangeForPeriod(period);
|
|
|
+ const { startTime, endTime } = getTimeRangeForPeriodWithMode(
|
|
|
+ period,
|
|
|
+ dailyResetInfo.normalized,
|
|
|
+ resetMode
|
|
|
+ );
|
|
|
const current =
|
|
|
type === "key"
|
|
|
? await sumKeyCostInTimeRange(id, startTime, endTime)
|
|
|
@@ -447,12 +623,33 @@ export class RateLimitService {
|
|
|
|
|
|
logger.info(`[RateLimit] Cache warmed for ${key}, value=${current} (rolling window)`);
|
|
|
}
|
|
|
+ } else if (period === "daily" && resetMode === "rolling") {
|
|
|
+ // daily 滚动窗口:使用 ZSET + Lua 脚本
|
|
|
+ if (current > 0) {
|
|
|
+ const now = Date.now();
|
|
|
+ const window24h = 24 * 60 * 60 * 1000;
|
|
|
+ const key = `${type}:${id}:cost_daily_rolling`;
|
|
|
+
|
|
|
+ await this.redis.eval(
|
|
|
+ TRACK_COST_DAILY_ROLLING_WINDOW,
|
|
|
+ 1,
|
|
|
+ key,
|
|
|
+ current.toString(),
|
|
|
+ now.toString(),
|
|
|
+ window24h.toString()
|
|
|
+ );
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ `[RateLimit] Cache warmed for ${key}, value=${current} (daily rolling window)`
|
|
|
+ );
|
|
|
+ }
|
|
|
} else {
|
|
|
- // 周/月固定窗口:使用 STRING + 动态 TTL
|
|
|
- const ttl = getTTLForPeriod(period);
|
|
|
- await this.redis.set(`${type}:${id}:cost_${period}`, current.toString(), "EX", ttl);
|
|
|
+ // daily fixed/周/月固定窗口:使用 STRING + 动态 TTL
|
|
|
+ const redisKey = period === "daily" ? `${period}_${dailyResetInfo.suffix}` : period;
|
|
|
+ const ttl = getTTLForPeriodWithMode(period, dailyResetInfo.normalized, resetMode);
|
|
|
+ await this.redis.set(`${type}:${id}:cost_${redisKey}`, current.toString(), "EX", ttl);
|
|
|
logger.info(
|
|
|
- `[RateLimit] Cache warmed for ${type}:${id}:cost_${period}, value=${current}, ttl=${ttl}s`
|
|
|
+ `[RateLimit] Cache warmed for ${type}:${id}:cost_${redisKey}, value=${current}, ttl=${ttl}s`
|
|
|
);
|
|
|
}
|
|
|
} catch (error) {
|