Browse Source

refactor(repo): migrate key.ts + provider.ts read paths to usage_ledger

ding113 1 week ago
parent
commit
3437a95fec
2 changed files with 80 additions and 117 deletions
  1. 69 82
      src/repository/key.ts
  2. 11 35
      src/repository/provider.ts

+ 69 - 82
src/repository/key.ts

@@ -2,7 +2,7 @@
 
 import { and, count, desc, eq, gt, gte, inArray, isNull, lt, or, sql, sum } from "drizzle-orm";
 import { db } from "@/drizzle/db";
-import { keys, messageRequest, providers, users } from "@/drizzle/schema";
+import { keys, providers, usageLedger, users } from "@/drizzle/schema";
 import { CHANNEL_API_KEYS_UPDATED, publishCacheInvalidation } from "@/lib/redis/pubsub";
 import {
   cacheActiveKey,
@@ -16,7 +16,7 @@ import { apiKeyVacuumFilter } from "@/lib/security/api-key-vacuum-filter";
 import { Decimal, toCostDecimal } from "@/lib/utils/currency";
 import type { CreateKeyData, Key, UpdateKeyData } from "@/types/key";
 import type { User } from "@/types/user";
-import { EXCLUDE_WARMUP_CONDITION } from "./_shared/message-request-conditions";
+import { LEDGER_BILLING_CONDITION } from "./_shared/ledger-conditions";
 import { toKey, toUser } from "./_shared/transformers";
 
 export async function findKeyById(id: number): Promise<Key | null> {
@@ -337,17 +337,16 @@ export async function findKeyUsageToday(
   const rows = await db
     .select({
       keyId: keys.id,
-      totalCost: sum(messageRequest.costUsd),
+      totalCost: sum(usageLedger.costUsd),
     })
     .from(keys)
     .leftJoin(
-      messageRequest,
+      usageLedger,
       and(
-        eq(messageRequest.key, keys.key),
-        isNull(messageRequest.deletedAt),
-        EXCLUDE_WARMUP_CONDITION,
-        gte(messageRequest.createdAt, today),
-        lt(messageRequest.createdAt, tomorrow)
+        eq(usageLedger.key, keys.key),
+        LEDGER_BILLING_CONDITION,
+        gte(usageLedger.createdAt, today),
+        lt(usageLedger.createdAt, tomorrow)
       )
     )
     .where(and(eq(keys.userId, userId), isNull(keys.deletedAt)))
@@ -382,23 +381,22 @@ export async function findKeyUsageTodayBatch(
     .select({
       userId: keys.userId,
       keyId: keys.id,
-      totalCost: sum(messageRequest.costUsd),
+      totalCost: sum(usageLedger.costUsd),
       totalTokens: sql<number>`COALESCE(SUM(
-        COALESCE(${messageRequest.inputTokens}, 0)::double precision +
-        COALESCE(${messageRequest.outputTokens}, 0)::double precision +
-        COALESCE(${messageRequest.cacheCreationInputTokens}, 0)::double precision +
-        COALESCE(${messageRequest.cacheReadInputTokens}, 0)::double precision
+        COALESCE(${usageLedger.inputTokens}, 0)::double precision +
+        COALESCE(${usageLedger.outputTokens}, 0)::double precision +
+        COALESCE(${usageLedger.cacheCreationInputTokens}, 0)::double precision +
+        COALESCE(${usageLedger.cacheReadInputTokens}, 0)::double precision
       ), 0::double precision)`,
     })
     .from(keys)
     .leftJoin(
-      messageRequest,
+      usageLedger,
       and(
-        eq(messageRequest.key, keys.key),
-        isNull(messageRequest.deletedAt),
-        EXCLUDE_WARMUP_CONDITION,
-        gte(messageRequest.createdAt, today),
-        lt(messageRequest.createdAt, tomorrow)
+        eq(usageLedger.key, keys.key),
+        LEDGER_BILLING_CONDITION,
+        gte(usageLedger.createdAt, today),
+        lt(usageLedger.createdAt, tomorrow)
       )
     )
     .where(and(inArray(keys.userId, userIds), isNull(keys.deletedAt)))
@@ -725,58 +723,50 @@ export async function findKeysWithStatistics(userId: number): Promise<KeyStatist
     // 查询今日调用次数
     const [todayCount] = await db
       .select({ count: count() })
-      .from(messageRequest)
+      .from(usageLedger)
       .where(
         and(
-          eq(messageRequest.key, key.key),
-          isNull(messageRequest.deletedAt),
-          EXCLUDE_WARMUP_CONDITION,
-          gte(messageRequest.createdAt, today),
-          lt(messageRequest.createdAt, tomorrow)
+          eq(usageLedger.key, key.key),
+          LEDGER_BILLING_CONDITION,
+          gte(usageLedger.createdAt, today),
+          lt(usageLedger.createdAt, tomorrow)
         )
       );
 
     // 查询最后使用时间和供应商
     const [lastUsage] = await db
       .select({
-        createdAt: messageRequest.createdAt,
+        createdAt: usageLedger.createdAt,
         providerName: providers.name,
       })
-      .from(messageRequest)
-      .innerJoin(providers, eq(messageRequest.providerId, providers.id))
-      .where(
-        and(
-          eq(messageRequest.key, key.key),
-          isNull(messageRequest.deletedAt),
-          EXCLUDE_WARMUP_CONDITION
-        )
-      )
-      .orderBy(desc(messageRequest.createdAt))
+      .from(usageLedger)
+      .innerJoin(providers, eq(usageLedger.finalProviderId, providers.id))
+      .where(and(eq(usageLedger.key, key.key), LEDGER_BILLING_CONDITION))
+      .orderBy(desc(usageLedger.createdAt))
       .limit(1);
 
     // 查询分模型统计(仅统计当天)
     const modelStatsRows = await db
       .select({
-        model: messageRequest.model,
+        model: usageLedger.model,
         callCount: sql<number>`count(*)::int`,
-        totalCost: sum(messageRequest.costUsd),
-        inputTokens: sql<number>`COALESCE(sum(${messageRequest.inputTokens}), 0)::double precision`,
-        outputTokens: sql<number>`COALESCE(sum(${messageRequest.outputTokens}), 0)::double precision`,
-        cacheCreationTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreationInputTokens}), 0)::double precision`,
-        cacheReadTokens: sql<number>`COALESCE(sum(${messageRequest.cacheReadInputTokens}), 0)::double precision`,
+        totalCost: sum(usageLedger.costUsd),
+        inputTokens: sql<number>`COALESCE(sum(${usageLedger.inputTokens}), 0)::double precision`,
+        outputTokens: sql<number>`COALESCE(sum(${usageLedger.outputTokens}), 0)::double precision`,
+        cacheCreationTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreationInputTokens}), 0)::double precision`,
+        cacheReadTokens: sql<number>`COALESCE(sum(${usageLedger.cacheReadInputTokens}), 0)::double precision`,
       })
-      .from(messageRequest)
+      .from(usageLedger)
       .where(
         and(
-          eq(messageRequest.key, key.key),
-          isNull(messageRequest.deletedAt),
-          EXCLUDE_WARMUP_CONDITION,
-          gte(messageRequest.createdAt, today),
-          lt(messageRequest.createdAt, tomorrow),
-          sql`${messageRequest.model} IS NOT NULL`
+          eq(usageLedger.key, key.key),
+          LEDGER_BILLING_CONDITION,
+          gte(usageLedger.createdAt, today),
+          lt(usageLedger.createdAt, tomorrow),
+          sql`${usageLedger.model} IS NOT NULL`
         )
       )
-      .groupBy(messageRequest.model)
+      .groupBy(usageLedger.model)
       .orderBy(desc(sql`count(*)`));
 
     const modelStats = modelStatsRows.map((row) => ({
@@ -852,20 +842,19 @@ export async function findKeysWithStatisticsBatch(
   // Step 2: Query today's call counts for all keys at once
   const todayCountRows = await db
     .select({
-      key: messageRequest.key,
+      key: usageLedger.key,
       count: count(),
     })
-    .from(messageRequest)
+    .from(usageLedger)
     .where(
       and(
-        inArray(messageRequest.key, keyStrings),
-        isNull(messageRequest.deletedAt),
-        EXCLUDE_WARMUP_CONDITION,
-        gte(messageRequest.createdAt, today),
-        lt(messageRequest.createdAt, tomorrow)
+        inArray(usageLedger.key, keyStrings),
+        LEDGER_BILLING_CONDITION,
+        gte(usageLedger.createdAt, today),
+        lt(usageLedger.createdAt, tomorrow)
       )
     )
-    .groupBy(messageRequest.key);
+    .groupBy(usageLedger.key);
 
   const todayCountMap = new Map<string, number>();
   for (const row of todayCountRows) {
@@ -879,15 +868,14 @@ export async function findKeysWithStatisticsBatch(
     SELECT k.key_val AS key, lr.created_at, p.name AS provider_name
     FROM unnest(${keyStrings}::varchar[]) AS k(key_val)
     LEFT JOIN LATERAL (
-      SELECT mr.created_at, mr.provider_id
-      FROM message_request mr
-      WHERE mr.key = k.key_val
-        AND mr.deleted_at IS NULL
-        AND (mr.blocked_by IS NULL OR mr.blocked_by <> 'warmup')
-      ORDER BY mr.created_at DESC NULLS LAST
+      SELECT ul.created_at, ul.final_provider_id
+      FROM usage_ledger ul
+      WHERE ul.key = k.key_val
+        AND ul.blocked_by IS NULL
+      ORDER BY ul.created_at DESC NULLS LAST
       LIMIT 1
     ) lr ON true
-    LEFT JOIN providers p ON lr.provider_id = p.id
+    LEFT JOIN providers p ON lr.final_provider_id = p.id
   `);
 
   const lastUsageMap = new Map<string, { createdAt: Date | null; providerName: string | null }>();
@@ -907,28 +895,27 @@ export async function findKeysWithStatisticsBatch(
   // Step 4: Query model statistics for all keys at once
   const modelStatsRows = await db
     .select({
-      key: messageRequest.key,
-      model: messageRequest.model,
+      key: usageLedger.key,
+      model: usageLedger.model,
       callCount: sql<number>`count(*)::int`,
-      totalCost: sum(messageRequest.costUsd),
-      inputTokens: sql<number>`COALESCE(sum(${messageRequest.inputTokens}), 0)::double precision`,
-      outputTokens: sql<number>`COALESCE(sum(${messageRequest.outputTokens}), 0)::double precision`,
-      cacheCreationTokens: sql<number>`COALESCE(sum(${messageRequest.cacheCreationInputTokens}), 0)::double precision`,
-      cacheReadTokens: sql<number>`COALESCE(sum(${messageRequest.cacheReadInputTokens}), 0)::double precision`,
+      totalCost: sum(usageLedger.costUsd),
+      inputTokens: sql<number>`COALESCE(sum(${usageLedger.inputTokens}), 0)::double precision`,
+      outputTokens: sql<number>`COALESCE(sum(${usageLedger.outputTokens}), 0)::double precision`,
+      cacheCreationTokens: sql<number>`COALESCE(sum(${usageLedger.cacheCreationInputTokens}), 0)::double precision`,
+      cacheReadTokens: sql<number>`COALESCE(sum(${usageLedger.cacheReadInputTokens}), 0)::double precision`,
     })
-    .from(messageRequest)
+    .from(usageLedger)
     .where(
       and(
-        inArray(messageRequest.key, keyStrings),
-        isNull(messageRequest.deletedAt),
-        EXCLUDE_WARMUP_CONDITION,
-        gte(messageRequest.createdAt, today),
-        lt(messageRequest.createdAt, tomorrow),
-        sql`${messageRequest.model} IS NOT NULL`
+        inArray(usageLedger.key, keyStrings),
+        LEDGER_BILLING_CONDITION,
+        gte(usageLedger.createdAt, today),
+        lt(usageLedger.createdAt, tomorrow),
+        sql`${usageLedger.model} IS NOT NULL`
       )
     )
-    .groupBy(messageRequest.key, messageRequest.model)
-    .orderBy(messageRequest.key, desc(sql`count(*)`));
+    .groupBy(usageLedger.key, usageLedger.model)
+    .orderBy(usageLedger.key, desc(sql`count(*)`));
 
   // Group model stats by key
   const modelStatsMap = new Map<

+ 11 - 35
src/repository/provider.ts

@@ -1437,7 +1437,7 @@ export async function getDistinctProviderGroups(): Promise<string[]> {
  * 包括:今天的总金额、今天的调用次数、最近一次调用时间和模型
  *
  * 性能优化:
- * - provider_stats: 先按最终供应商聚合,再与 providers 做 LEFT JOIN,避免 providers × message_request 的笛卡尔积
+ * - provider_stats: 先按最终供应商聚合,再与 providers 做 LEFT JOIN,避免 providers × usage_ledger 的笛卡尔积
  * - bounds: 用“按时区计算的时间范围”过滤 created_at,便于命中 created_at 索引
  * - DST 兼容:对“本地日界/近 7 日”先在 timestamp 上做 +interval,再 AT TIME ZONE 回到 timestamptz,避免夏令时跨日偏移
  * - latest_call: 限制近 7 天范围,避免扫描历史数据
@@ -1483,8 +1483,6 @@ export async function getProviderStatistics(): Promise<ProviderStatisticsRow[]>
     }
 
     const promise: Promise<ProviderStatisticsRow[]> = (async () => {
-      // 使用 providerChain 最后一项的 providerId 来确定最终供应商(兼容重试切换)
-      // 如果 provider_chain 为空(无重试),则使用 provider_id 字段
       const query = sql`
          WITH bounds AS (
            SELECT
@@ -1495,45 +1493,23 @@ export async function getProviderStatistics(): Promise<ProviderStatisticsRow[]>
          provider_stats AS (
            -- 先按最终供应商聚合,再与 providers 做 LEFT JOIN,避免 providers × 今日请求 的笛卡尔积
            SELECT
-            mr.final_provider_id,
-            COALESCE(SUM(mr.cost_usd), 0) AS today_cost,
+            final_provider_id,
+            COALESCE(SUM(cost_usd), 0) AS today_cost,
             COUNT(*)::integer AS today_calls
-          FROM (
-            SELECT
-              CASE
-                WHEN provider_chain IS NULL OR provider_chain = '[]'::jsonb THEN provider_id
-                WHEN (provider_chain->-1->>'id') ~ '^[0-9]+$' THEN (provider_chain->-1->>'id')::int
-                ELSE provider_id
-              END AS final_provider_id,
-              cost_usd
-            FROM message_request
-            WHERE deleted_at IS NULL
-              AND (blocked_by IS NULL OR blocked_by <> 'warmup')
-              AND created_at >= (SELECT today_start FROM bounds)
-              AND created_at < (SELECT tomorrow_start FROM bounds)
-          ) mr
-          GROUP BY mr.final_provider_id
+          FROM usage_ledger
+          WHERE blocked_by IS NULL
+            AND created_at >= (SELECT today_start FROM bounds)
+            AND created_at < (SELECT tomorrow_start FROM bounds)
+          GROUP BY final_provider_id
         ),
         latest_call AS (
           SELECT DISTINCT ON (final_provider_id)
             final_provider_id,
             created_at AS last_call_time,
             model AS last_call_model
-          FROM (
-            SELECT
-              CASE
-                WHEN provider_chain IS NULL OR provider_chain = '[]'::jsonb THEN provider_id
-                WHEN (provider_chain->-1->>'id') ~ '^[0-9]+$' THEN (provider_chain->-1->>'id')::int
-                ELSE provider_id
-              END AS final_provider_id,
-              id,
-              created_at,
-              model
-            FROM message_request
-            WHERE deleted_at IS NULL
-              AND (blocked_by IS NULL OR blocked_by <> 'warmup')
-              AND created_at >= (SELECT last7_start FROM bounds)
-          ) mr
+          FROM usage_ledger
+          WHERE blocked_by IS NULL
+            AND created_at >= (SELECT last7_start FROM bounds)
           -- 性能优化:添加 7 天时间范围限制(避免扫描历史数据)
           ORDER BY final_provider_id, created_at DESC, id DESC
         )