ソースを参照

feat(ledger): usage_ledger decoupling (#811)

* test(ledger): fix unit tests after ledger migration

* fix(ledger): address bugbot review comments

- ledger-fallback: log warning on catch instead of silently forcing ledger-only; preserve last known state
- usage-logs: use finalProviderId (not providerId) in findUsageLogsStats for consistency
- usage-logs: guard minRetryCount innerJoin with isLedgerOnlyMode() to prevent zero results after log cleanup
- cleanup-immunity test: replace toContain on a regex-like string (which only matched it as literal text) with toMatch and a proper regex
- backfill: use pg_try_advisory_xact_lock (transaction-scoped) instead of session-scoped lock to be pool-safe
- trigger: handle warmup UPDATE edge case — update existing ledger row blocked_by when row transitions to warmup

* chore: format code (feat-usage-ledger-decoupling-c605ad6)

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Ding 1 ヶ月 前
コミット
8be7490b65

+ 2 - 0
drizzle/0073_magical_manta.sql

@@ -47,6 +47,8 @@ DECLARE
   v_is_success boolean;
 BEGIN
   IF NEW.blocked_by = 'warmup' THEN
+    -- If a ledger row already exists (row was originally non-warmup), mark it as warmup
+    UPDATE usage_ledger SET blocked_by = 'warmup' WHERE request_id = NEW.id;
     RETURN NEW;
   END IF;
 

+ 54 - 50
src/lib/ledger-backfill/service.ts

@@ -15,27 +15,30 @@ export async function backfillUsageLedger(): Promise<BackfillUsageLedgerSummary>
   const startTime = Date.now();
   const LOCK_KEY = 20260101;
 
-  const result = await db.execute(sql`
-    SELECT pg_try_advisory_lock(${LOCK_KEY}) AS acquired
-  `);
+  // Use pg_try_advisory_xact_lock (transaction-scoped) so lock/unlock always happen
+  // on the same connection — safe with connection pools.
+  return await db.transaction(async (tx) => {
+    const lockResult = await tx.execute(sql`
+      SELECT pg_try_advisory_xact_lock(${LOCK_KEY}) AS acquired
+    `);
 
-  const acquired = (result as unknown as Array<{ acquired: boolean }>)[0]?.acquired;
-  if (!acquired) {
-    return {
-      totalProcessed: 0,
-      totalInserted: 0,
-      durationMs: Date.now() - startTime,
-      alreadyExisted: 0,
-    };
-  }
+    const acquired = (lockResult as unknown as Array<{ acquired: boolean }>)[0]?.acquired;
+    if (!acquired) {
+      return {
+        totalProcessed: 0,
+        totalInserted: 0,
+        durationMs: Date.now() - startTime,
+        alreadyExisted: 0,
+      };
+    }
 
-  try {
-    let totalProcessed = 0;
-    let totalInserted = 0;
-    let lastId = 0;
+    try {
+      let totalProcessed = 0;
+      let totalInserted = 0;
+      let lastId = 0;
 
-    while (true) {
-      const batchResult = await db.execute(sql`
+      while (true) {
+        const batchResult = await tx.execute(sql`
         WITH batch AS (
           SELECT
             mr.id,
@@ -132,41 +135,42 @@ export async function backfillUsageLedger(): Promise<BackfillUsageLedgerSummary>
           COALESCE((SELECT MAX(id) FROM batch), 0)::integer AS max_id
       `);
 
-      const batchRow = (
-        batchResult as unknown as Array<{
-          processed?: number | string;
-          inserted?: number | string;
-          max_id?: number | string;
-        }>
-      )[0];
+        const batchRow = (
+          batchResult as unknown as Array<{
+            processed?: number | string;
+            inserted?: number | string;
+            max_id?: number | string;
+          }>
+        )[0];
 
-      const processed = Number(batchRow?.processed ?? 0);
-      const inserted = Number(batchRow?.inserted ?? 0);
-      const maxId = Number(batchRow?.max_id ?? 0);
+        const processed = Number(batchRow?.processed ?? 0);
+        const inserted = Number(batchRow?.inserted ?? 0);
+        const maxId = Number(batchRow?.max_id ?? 0);
 
-      if (processed === 0) {
-        break;
-      }
+        if (processed === 0) {
+          break;
+        }
 
-      totalProcessed += processed;
-      totalInserted += inserted;
-      lastId = maxId;
+        totalProcessed += processed;
+        totalInserted += inserted;
+        lastId = maxId;
 
-      logger.info("Backfill progress", {
-        processed: totalProcessed,
-        inserted: totalInserted,
-        elapsed: Date.now() - startTime,
-      });
-    }
+        logger.info("Backfill progress", {
+          processed: totalProcessed,
+          inserted: totalInserted,
+          elapsed: Date.now() - startTime,
+        });
+      }
 
-    const durationMs = Date.now() - startTime;
-    return {
-      totalProcessed,
-      totalInserted,
-      durationMs,
-      alreadyExisted: totalProcessed - totalInserted,
-    };
-  } finally {
-    await db.execute(sql`SELECT pg_advisory_unlock(${LOCK_KEY})`);
-  }
+      const durationMs = Date.now() - startTime;
+      return {
+        totalProcessed,
+        totalInserted,
+        durationMs,
+        alreadyExisted: totalProcessed - totalInserted,
+      };
+    } finally {
+      // pg_try_advisory_xact_lock is automatically released when the transaction ends
+    }
+  });
 }

+ 2 - 0
src/lib/ledger-backfill/trigger.sql

@@ -5,6 +5,8 @@ DECLARE
   v_is_success boolean;
 BEGIN
   IF NEW.blocked_by = 'warmup' THEN
+    -- If a ledger row already exists (row was originally non-warmup), mark it as warmup
+    UPDATE usage_ledger SET blocked_by = 'warmup' WHERE request_id = NEW.id;
     RETURN NEW;
   END IF;
 

+ 5 - 3
src/lib/ledger-fallback.ts

@@ -2,6 +2,7 @@ import "server-only";
 
 import { sql } from "drizzle-orm";
 import { db } from "@/drizzle/db";
+import { logger } from "@/lib/logger";
 
 let cachedResult: boolean | null = null;
 let cacheExpiry = 0;
@@ -18,9 +19,10 @@ export async function isLedgerOnlyMode(): Promise<boolean> {
     cachedResult = !hasData;
     cacheExpiry = now + 60_000;
     return cachedResult;
-  } catch {
-    cachedResult = true;
+  } catch (err) {
+    logger.warn("[ledger-fallback] Failed to check message_request existence", { error: err });
+    cachedResult = cachedResult ?? false;
     cacheExpiry = now + 60_000;
-    return true;
+    return cachedResult;
   }
 }

+ 4 - 2
src/repository/usage-logs.ts

@@ -975,7 +975,7 @@ export async function findUsageLogsStats(
   }
 
   if (providerId !== undefined) {
-    conditions.push(eq(usageLedger.providerId, providerId));
+    conditions.push(eq(usageLedger.finalProviderId, providerId));
   }
 
   const trimmedSessionId = filters.sessionId?.trim();
@@ -1029,8 +1029,10 @@ export async function findUsageLogsStats(
       ? baseQuery.innerJoin(keysTable, eq(usageLedger.key, keysTable.key))
       : baseQuery;
 
+  // In ledger-only mode, message_request is empty — skip the innerJoin to avoid zeroing all results
+  const ledgerOnly = await isLedgerOnlyMode();
   const query =
-    filters.minRetryCount !== undefined
+    filters.minRetryCount !== undefined && !ledgerOnly
       ? queryByKey.innerJoin(messageRequest, eq(usageLedger.requestId, messageRequest.id))
       : queryByKey;
 

+ 22 - 12
tests/unit/lib/endpoint-circuit-breaker.test.ts

@@ -275,13 +275,12 @@ describe("endpoint-circuit-breaker", () => {
   test("triggerEndpointCircuitBreakerAlert should call sendCircuitBreakerAlert", async () => {
     vi.resetModules();
 
-    const sendAlertMock = vi.fn(async () => {});
     vi.doMock("@/lib/config/env.schema", () => ({
       getEnvConfig: () => ({ ENABLE_ENDPOINT_CIRCUIT_BREAKER: true }),
     }));
     vi.doMock("@/lib/logger", () => ({ logger: createLoggerMock() }));
     vi.doMock("@/lib/notification/notifier", () => ({
-      sendCircuitBreakerAlert: sendAlertMock,
+      sendCircuitBreakerAlert: vi.fn(async () => {}),
     }));
     vi.doMock("@/repository", () => ({
       findProviderEndpointById: vi.fn(async () => null),
@@ -289,11 +288,14 @@ describe("endpoint-circuit-breaker", () => {
 
     // recordEndpointFailure 会 non-blocking 触发告警;先让 event-loop 跑完再清空计数,避免串台导致误判
     await flushPromises();
-    sendAlertMock.mockClear();
 
     // Prime module cache for dynamic import() consumers
     await import("@/lib/config/env.schema");
-    await import("@/lib/notification/notifier");
+    const notifierModule = await import("@/lib/notification/notifier");
+    const sendAlertSpy = vi
+      .spyOn(notifierModule, "sendCircuitBreakerAlert")
+      .mockResolvedValue(undefined);
+    sendAlertSpy.mockClear();
 
     const { triggerEndpointCircuitBreakerAlert } = await import("@/lib/endpoint-circuit-breaker");
 
@@ -304,8 +306,11 @@ describe("endpoint-circuit-breaker", () => {
       "connection refused"
     );
 
-    expect(sendAlertMock).toHaveBeenCalledTimes(1);
-    expect(sendAlertMock).toHaveBeenCalledWith({
+    const endpoint5Calls = sendAlertSpy.mock.calls
+      .map((call) => call[0] as Record<string, unknown>)
+      .filter((payload) => payload.endpointId === 5);
+    expect(endpoint5Calls).toHaveLength(1);
+    expect(endpoint5Calls[0]).toEqual({
       providerId: 0,
       providerName: "endpoint:5",
       failureCount: 3,
@@ -320,12 +325,11 @@ describe("endpoint-circuit-breaker", () => {
   test("triggerEndpointCircuitBreakerAlert should include endpointUrl when available", async () => {
     vi.resetModules();
 
-    const sendAlertMock = vi.fn(async () => {});
     vi.doMock("@/lib/config/env.schema", () => ({
       getEnvConfig: () => ({ ENABLE_ENDPOINT_CIRCUIT_BREAKER: true }),
     }));
     vi.doMock("@/lib/notification/notifier", () => ({
-      sendCircuitBreakerAlert: sendAlertMock,
+      sendCircuitBreakerAlert: vi.fn(async () => {}),
     }));
     vi.doMock("@/repository", () => ({
       findProviderEndpointById: vi.fn(async () => ({
@@ -350,18 +354,24 @@ describe("endpoint-circuit-breaker", () => {
 
     // recordEndpointFailure 会 non-blocking 触发告警;先让 event-loop 跑完再清空计数,避免串台导致误判
     await flushPromises();
-    sendAlertMock.mockClear();
 
     // Prime module cache for dynamic import() consumers
     await import("@/lib/config/env.schema");
-    await import("@/lib/notification/notifier");
+    const notifierModule = await import("@/lib/notification/notifier");
+    const sendAlertSpy = vi
+      .spyOn(notifierModule, "sendCircuitBreakerAlert")
+      .mockResolvedValue(undefined);
+    sendAlertSpy.mockClear();
 
     const { triggerEndpointCircuitBreakerAlert } = await import("@/lib/endpoint-circuit-breaker");
 
     await triggerEndpointCircuitBreakerAlert(10, 3, "2026-01-01T00:05:00.000Z", "timeout");
 
-    expect(sendAlertMock).toHaveBeenCalledTimes(1);
-    expect(sendAlertMock).toHaveBeenCalledWith({
+    const endpoint10Calls = sendAlertSpy.mock.calls
+      .map((call) => call[0] as Record<string, unknown>)
+      .filter((payload) => payload.endpointId === 10);
+    expect(endpoint10Calls).toHaveLength(1);
+    expect(endpoint10Calls[0]).toEqual({
       providerId: 1,
       providerName: "Custom Endpoint",
       failureCount: 3,

+ 18 - 0
tests/unit/repository/leaderboard-timezone-parentheses.test.ts

@@ -76,6 +76,24 @@ vi.mock("@/drizzle/db", () => {
 });
 
 vi.mock("@/drizzle/schema", () => ({
+  usageLedger: {
+    userId: "userId",
+    providerId: "providerId",
+    finalProviderId: "finalProviderId",
+    costUsd: "costUsd",
+    inputTokens: "inputTokens",
+    outputTokens: "outputTokens",
+    cacheCreationInputTokens: "cacheCreationInputTokens",
+    cacheReadInputTokens: "cacheReadInputTokens",
+    blockedBy: "blockedBy",
+    createdAt: "createdAt",
+    ttfbMs: "ttfbMs",
+    durationMs: "durationMs",
+    statusCode: "statusCode",
+    isSuccess: "isSuccess",
+    model: "model",
+    originalModel: "originalModel",
+  },
   messageRequest: {
     deletedAt: "deletedAt",
     providerId: "providerId",

+ 20 - 8
tests/unit/repository/usage-logs-sessionid-filter.test.ts

@@ -64,15 +64,21 @@ describe("Usage logs sessionId filter", () => {
         execute: vi.fn(async () => ({ count: 0 })),
       },
     }));
+    vi.doMock("@/lib/ledger-fallback", () => ({
+      isLedgerOnlyMode: vi.fn(async () => true),
+    }));
 
     const { findUsageLogsBatch } = await import("@/repository/usage-logs");
     await findUsageLogsBatch({});
     await findUsageLogsBatch({ sessionId: "   " });
 
-    expect(whereArgs).toHaveLength(2);
-    const baseWhereSql = sqlToString(whereArgs[0]).toLowerCase();
-    const blankWhereSql = sqlToString(whereArgs[1]).toLowerCase();
-    expect(blankWhereSql).toBe(baseWhereSql);
+    expect(whereArgs).toHaveLength(4);
+    const basePrimaryWhereSql = sqlToString(whereArgs[0]).toLowerCase();
+    const baseLedgerWhereSql = sqlToString(whereArgs[1]).toLowerCase();
+    const blankPrimaryWhereSql = sqlToString(whereArgs[2]).toLowerCase();
+    const blankLedgerWhereSql = sqlToString(whereArgs[3]).toLowerCase();
+    expect(blankPrimaryWhereSql).toBe(basePrimaryWhereSql);
+    expect(blankLedgerWhereSql).toBe(baseLedgerWhereSql);
   });
 
   test("findUsageLogsBatch: sessionId 应 trim 后精确匹配", async () => {
@@ -87,14 +93,20 @@ describe("Usage logs sessionId filter", () => {
         execute: vi.fn(async () => ({ count: 0 })),
       },
     }));
+    vi.doMock("@/lib/ledger-fallback", () => ({
+      isLedgerOnlyMode: vi.fn(async () => true),
+    }));
 
     const { findUsageLogsBatch } = await import("@/repository/usage-logs");
     await findUsageLogsBatch({ sessionId: "  abc  " });
 
-    expect(whereArgs.length).toBeGreaterThan(0);
-    const whereSql = sqlToString(whereArgs[0]).toLowerCase();
-    expect(whereSql).toContain("abc");
-    expect(whereSql).not.toContain("  abc  ");
+    expect(whereArgs).toHaveLength(2);
+    const primaryWhereSql = sqlToString(whereArgs[0]).toLowerCase();
+    const ledgerWhereSql = sqlToString(whereArgs[1]).toLowerCase();
+    expect(primaryWhereSql).toContain("abc");
+    expect(primaryWhereSql).not.toContain("  abc  ");
+    expect(ledgerWhereSql).toContain("abc");
+    expect(ledgerWhereSql).not.toContain("  abc  ");
   });
 
   test("findUsageLogsWithDetails: sessionId 为空/空白不应追加条件", async () => {

+ 1 - 1
tests/unit/usage-ledger/cleanup-immunity.test.ts

@@ -7,7 +7,7 @@ const usersTs = readFileSync(resolve(process.cwd(), "src/actions/users.ts"), "ut
 
 describe("usage_ledger cleanup immunity", () => {
   it("log cleanup service never imports or queries usageLedger", () => {
-    expect(serviceTs).not.toContain("import.*usageLedger");
+    expect(serviceTs).not.toMatch(/import\b.*\busageLedger\b/);
     expect(serviceTs).not.toMatch(/from.*schema.*usageLedger/);
     expect(serviceTs).not.toContain("db.delete(usageLedger)");
     expect(serviceTs).not.toContain('from("usage_ledger")');