Browse Source

feat(repo): add findSessionOriginChain repository function

ding113 1 week ago
parent
commit
b685f2d4a2
1 changed files with 24 additions and 1 deletions
  1. 24 1
      src/repository/message.ts

+ 24 - 1
src/repository/message.ts

@@ -5,7 +5,7 @@ import { db } from "@/drizzle/db";
 import { keys as keysTable, messageRequest, providers, users } from "@/drizzle/schema";
 import { getEnvConfig } from "@/lib/config/env.schema";
 import { formatCostForStorage } from "@/lib/utils/currency";
-import type { CreateMessageRequestData, MessageRequest } from "@/types/message";
+import type { CreateMessageRequestData, MessageRequest, ProviderChainItem } from "@/types/message";
 import type { SpecialSetting } from "@/types/special-settings";
 import { EXCLUDE_WARMUP_CONDITION } from "./_shared/message-request-conditions";
 import { toMessageRequest } from "./_shared/transformers";
@@ -277,6 +277,29 @@ export async function findMessageRequestBySessionId(
   return toMessageRequest(result);
 }
 
+export async function findSessionOriginChain(
+  sessionId: string
+): Promise<ProviderChainItem[] | null> {
+  const [row] = await db
+    .select({
+      providerChain: messageRequest.providerChain,
+    })
+    .from(messageRequest)
+    .where(
+      and(
+        eq(messageRequest.sessionId, sessionId),
+        isNull(messageRequest.deletedAt),
+        EXCLUDE_WARMUP_CONDITION,
+        sql`${messageRequest.providerChain} IS NOT NULL`
+      )
+    )
+    .orderBy(asc(messageRequest.requestSequence))
+    .limit(1);
+
+  if (!row?.providerChain) return null;
+  return row.providerChain as ProviderChainItem[];
+}
+
 /**
  * 按 (sessionId, requestSequence) 获取请求的审计字段(用于 Session 详情页补齐特殊设置展示)
  */