|
|
@@ -17,36 +17,52 @@ export function createRateLimiter(
|
|
|
const dict = i18n(localeFromRequest(request))
|
|
|
|
|
|
const limits = Subscription.getFreeLimits()
|
|
|
- const limitValue =
|
|
|
- limits.checkHeader && !request.headers.get(limits.checkHeader)
|
|
|
- ? limits.fallbackValue
|
|
|
- : (rateLimit ?? limits.dailyRequests)
|
|
|
+ const headerExists = request.headers.has(limits.checkHeader)
|
|
|
+ const dailyLimit = !headerExists ? limits.fallbackValue : (rateLimit ?? limits.dailyRequests)
|
|
|
+ const isDefaultModel = headerExists && !rateLimit
|
|
|
|
|
|
const ip = !rawIp.length ? "unknown" : rawIp
|
|
|
const now = Date.now()
|
|
|
- const interval = rateLimit ? `${buildYYYYMMDD(now)}${modelId.substring(0, 2)}` : buildYYYYMMDD(now)
|
|
|
+ const lifetimeInterval = ""
|
|
|
+ const dailyInterval = rateLimit ? `${buildYYYYMMDD(now)}${modelId.substring(0, 2)}` : buildYYYYMMDD(now)
|
|
|
+
|
|
|
+ let _isNew: boolean
|
|
|
|
|
|
return {
|
|
|
- track: async () => {
|
|
|
- await Database.use((tx) =>
|
|
|
- tx
|
|
|
- .insert(IpRateLimitTable)
|
|
|
- .values({ ip, interval, count: 1 })
|
|
|
- .onDuplicateKeyUpdate({ set: { count: sql`${IpRateLimitTable.count} + 1` } }),
|
|
|
- )
|
|
|
- },
|
|
|
check: async () => {
|
|
|
const rows = await Database.use((tx) =>
|
|
|
tx
|
|
|
.select({ interval: IpRateLimitTable.interval, count: IpRateLimitTable.count })
|
|
|
.from(IpRateLimitTable)
|
|
|
- .where(and(eq(IpRateLimitTable.ip, ip), inArray(IpRateLimitTable.interval, [interval]))),
|
|
|
+ .where(
|
|
|
+ and(
|
|
|
+ eq(IpRateLimitTable.ip, ip),
|
|
|
+ isDefaultModel
|
|
|
+ ? inArray(IpRateLimitTable.interval, [lifetimeInterval, dailyInterval])
|
|
|
+ : inArray(IpRateLimitTable.interval, [dailyInterval]),
|
|
|
+ ),
|
|
|
+ ),
|
|
|
)
|
|
|
- const total = rows.reduce((sum, r) => sum + r.count, 0)
|
|
|
- logger.debug(`rate limit total: ${total}`)
|
|
|
- if (total >= limitValue)
|
|
|
+ const lifetimeCount = rows.find((r) => r.interval === lifetimeInterval)?.count ?? 0
|
|
|
+ const dailyCount = rows.find((r) => r.interval === dailyInterval)?.count ?? 0
|
|
|
+ logger.debug(`rate limit lifetime: ${lifetimeCount}, daily: ${dailyCount}`)
|
|
|
+
|
|
|
+ _isNew = isDefaultModel && lifetimeCount < dailyLimit * 7
|
|
|
+
|
|
|
+ if ((_isNew && dailyCount >= dailyLimit * 2) || (!_isNew && dailyCount >= dailyLimit))
|
|
|
throw new FreeUsageLimitError(dict["zen.api.error.rateLimitExceeded"], getRetryAfterDay(now))
|
|
|
},
|
|
|
+ track: async () => {
|
|
|
+ await Database.use((tx) =>
|
|
|
+ tx
|
|
|
+ .insert(IpRateLimitTable)
|
|
|
+ .values([
|
|
|
+ { ip, interval: dailyInterval, count: 1 },
|
|
|
+ ...(_isNew ? [{ ip, interval: lifetimeInterval, count: 1 }] : []),
|
|
|
+ ])
|
|
|
+ .onDuplicateKeyUpdate({ set: { count: sql`${IpRateLimitTable.count} + 1` } }),
|
|
|
+ )
|
|
|
+ },
|
|
|
}
|
|
|
}
|
|
|
|