Explorar el Código

feat(actions): add getSessionOriginChain server action

ding113 hace 2 semanas
padre
commit
7e9c3fbd33

+ 57 - 0
src/actions/session-origin-chain.ts

@@ -0,0 +1,57 @@
+"use server";
+
+import { and, eq, inArray, isNull, or } from "drizzle-orm";
+import { db } from "@/drizzle/db";
+import { messageRequest } from "@/drizzle/schema";
+import { getSession } from "@/lib/auth";
+import { logger } from "@/lib/logger";
+import { findKeyList } from "@/repository/key";
+import { findSessionOriginChain } from "@/repository/message";
+import type { ProviderChainItem } from "@/types/message";
+import type { ActionResult } from "./types";
+
+export async function getSessionOriginChain(
+  sessionId: string
+): Promise<ActionResult<ProviderChainItem[] | null>> {
+  try {
+    const session = await getSession();
+    if (!session) {
+      return { ok: false, error: "未登录" };
+    }
+
+    if (session.user.role !== "admin") {
+      const userKeys = await findKeyList(session.user.id);
+      const userKeyValues = userKeys.map((key) => key.key);
+
+      const ownershipCondition =
+        userKeyValues.length > 0
+          ? or(
+              eq(messageRequest.userId, session.user.id),
+              inArray(messageRequest.key, userKeyValues)
+            )
+          : eq(messageRequest.userId, session.user.id);
+
+      const [ownedSession] = await db
+        .select({ id: messageRequest.id })
+        .from(messageRequest)
+        .where(
+          and(
+            eq(messageRequest.sessionId, sessionId),
+            isNull(messageRequest.deletedAt),
+            ownershipCondition
+          )
+        )
+        .limit(1);
+
+      if (!ownedSession) {
+        return { ok: false, error: "无权访问该 Session" };
+      }
+    }
+
+    const chain = await findSessionOriginChain(sessionId);
+    return { ok: true, data: chain ?? null };
+  } catch (error) {
+    logger.error("获取会话来源链失败:", error);
+    return { ok: false, error: "获取会话来源链失败" };
+  }
+}

+ 109 - 0
tests/unit/actions/session-origin-chain.test.ts

@@ -0,0 +1,109 @@
+import { beforeEach, describe, expect, test, vi } from "vitest";
+import type { ProviderChainItem } from "@/types/message";
+
+const getSessionMock = vi.fn();
+const findSessionOriginChainMock = vi.fn();
+const findKeyListMock = vi.fn();
+
+const dbSelectMock = vi.fn();
+const dbFromMock = vi.fn();
+const dbWhereMock = vi.fn();
+const dbLimitMock = vi.fn();
+
+vi.mock("@/lib/auth", () => ({
+  getSession: getSessionMock,
+}));
+
+vi.mock("@/repository/message", () => ({
+  findSessionOriginChain: findSessionOriginChainMock,
+}));
+
+vi.mock("@/repository/key", () => ({
+  findKeyList: findKeyListMock,
+}));
+
+vi.mock("@/drizzle/db", () => ({
+  db: {
+    select: dbSelectMock,
+  },
+}));
+
+describe("getSessionOriginChain", () => {
+  beforeEach(() => {
+    vi.clearAllMocks();
+
+    dbSelectMock.mockReturnValue({ from: dbFromMock });
+    dbFromMock.mockReturnValue({ where: dbWhereMock });
+    dbWhereMock.mockReturnValue({ limit: dbLimitMock });
+    dbLimitMock.mockResolvedValue([{ id: 1 }]);
+
+    findKeyListMock.mockResolvedValue([{ key: "user-key-1" }]);
+  });
+
+  test("admin happy path: returns provider chain", async () => {
+    getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } });
+
+    const chain: ProviderChainItem[] = [
+      {
+        id: 11,
+        name: "provider-a",
+        reason: "initial_selection",
+      },
+    ];
+    findSessionOriginChainMock.mockResolvedValue(chain);
+
+    const { getSessionOriginChain } = await import("@/actions/session-origin-chain");
+    const result = await getSessionOriginChain("sess-admin");
+
+    expect(result).toEqual({ ok: true, data: chain });
+    expect(findSessionOriginChainMock).toHaveBeenCalledWith("sess-admin");
+    expect(findKeyListMock).not.toHaveBeenCalled();
+    expect(dbSelectMock).not.toHaveBeenCalled();
+  });
+
+  test("non-admin happy path: returns provider chain after ownership check", async () => {
+    getSessionMock.mockResolvedValue({ user: { id: 2, role: "user" } });
+
+    const chain: ProviderChainItem[] = [
+      {
+        id: 22,
+        name: "provider-b",
+        reason: "session_reuse",
+      },
+    ];
+    findSessionOriginChainMock.mockResolvedValue(chain);
+
+    const { getSessionOriginChain } = await import("@/actions/session-origin-chain");
+    const result = await getSessionOriginChain("sess-user");
+
+    expect(result).toEqual({ ok: true, data: chain });
+    expect(findKeyListMock).toHaveBeenCalledWith(2);
+    expect(dbSelectMock).toHaveBeenCalledTimes(1);
+    expect(findSessionOriginChainMock).toHaveBeenCalledWith("sess-user");
+  });
+
+  test("unauthenticated: returns not logged in", async () => {
+    getSessionMock.mockResolvedValue(null);
+
+    const { getSessionOriginChain } = await import("@/actions/session-origin-chain");
+    const result = await getSessionOriginChain("sess-no-auth");
+
+    expect(result).toEqual({ ok: false, error: "未登录" });
+    expect(findSessionOriginChainMock).not.toHaveBeenCalled();
+    expect(findKeyListMock).not.toHaveBeenCalled();
+    expect(dbSelectMock).not.toHaveBeenCalled();
+  });
+
+  test("not found: returns ok with null data", async () => {
+    getSessionMock.mockResolvedValue({ user: { id: 1, role: "admin" } });
+    findSessionOriginChainMock.mockResolvedValue(null);
+
+    const { getSessionOriginChain } = await import("@/actions/session-origin-chain");
+    const result = await getSessionOriginChain("sess-not-found");
+
+    expect(result).toEqual({ ok: true, data: null });
+    expect(findSessionOriginChainMock).toHaveBeenCalledWith("sess-not-found");
+    expect(findKeyListMock).not.toHaveBeenCalled();
+    expect(dbSelectMock).not.toHaveBeenCalled();
+  });
+});