Ver Fonte

Add expires_at token tracking and /v2/chat/test admin endpoint

- Add expires_at column to track token expiration in database (all backends)
- Calculate expires_at from OIDC expiresIn response on token refresh
- Add /v2/chat/test endpoint with admin auth and optional account selection
- Only refresh token when expired, avoiding unnecessary API calls
- Add account selector dropdown in chat test tab UI
- Change chat test to use /v2/chat/test endpoint (no API key exposure)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Zhi Yang há 1 mês atrás
pai
commit
9c6e8f93c9
3 ficheiros alterados com 88 adições e 18 exclusões
  1. 33 8
      app.py
  2. 25 3
      db.py
  3. 30 7
      frontend/index.html

+ 33 - 8
app.py

@@ -434,6 +434,8 @@ async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
 
         new_access = data.get("accessToken")
         new_refresh = data.get("refreshToken", acc.get("refreshToken"))
+        expires_in = data.get("expiresIn", 3600)  # Default 1 hour if not provided
+        expires_at = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime(time.time() + expires_in))
         now = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
         status = "success"
     except httpx.HTTPError as e:
@@ -469,10 +471,10 @@ async def refresh_access_token_in_db(account_id: str) -> Dict[str, Any]:
     await _db.execute(
         """
         UPDATE accounts
-        SET accessToken=?, refreshToken=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
+        SET accessToken=?, refreshToken=?, expires_at=?, last_refresh_time=?, last_refresh_status=?, updated_at=?
         WHERE id=?
         """,
-        (new_access, new_refresh, now, status, now, account_id),
+        (new_access, new_refresh, expires_at, now, status, now, account_id),
     )
 
     row2 = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
@@ -1008,8 +1010,8 @@ async def _create_account_from_tokens(
     acc_id = str(uuid.uuid4())
     await _db.execute(
         """
-        INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
-        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
         """,
         (
             acc_id,
@@ -1024,6 +1026,7 @@ async def _create_account_from_tokens(
             now,
             now,
             1 if enabled else 0,
+            None,  # expires_at - will be set on first refresh
         ),
     )
     row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
@@ -1176,8 +1179,8 @@ if CONSOLE_ENABLED:
         enabled_val = 1 if (body.enabled is None or body.enabled) else 0
         await _db.execute(
             """
-            INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
-            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+            INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
             """,
             (
                 acc_id,
@@ -1192,6 +1195,7 @@ if CONSOLE_ENABLED:
                 now,
                 now,
                 enabled_val,
+                None,  # expires_at - will be set on first refresh
             ),
         )
         row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (acc_id,))
@@ -1233,8 +1237,8 @@ if CONSOLE_ENABLED:
 
             await _db.execute(
                 """
-                INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+                INSERT INTO accounts (id, label, clientId, clientSecret, refreshToken, accessToken, other, last_refresh_time, last_refresh_status, created_at, updated_at, enabled, expires_at)
+                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                 """,
                 (
                     acc_id,
@@ -1249,6 +1253,7 @@ if CONSOLE_ENABLED:
                     now,
                     now,
                     0,  # 初始为禁用状态
+                    None,  # expires_at - will be set on first refresh
                 ),
             )
             new_account_ids.append(acc_id)
@@ -1325,6 +1330,26 @@ if CONSOLE_ENABLED:
     async def manual_refresh(account_id: str, _: bool = Depends(verify_admin_password)):
         return await refresh_access_token_in_db(account_id)
 
+    @app.post("/v2/chat/test")
+    async def admin_chat_test(req: ChatCompletionRequest, account_id: Optional[str] = None, _: bool = Depends(verify_admin_password)):
+        """Admin chat test - uses admin auth, selects account by id or random."""
+        if account_id:
+            row = await _db.fetchone("SELECT * FROM accounts WHERE id=?", (account_id,))
+            if not row:
+                raise HTTPException(status_code=404, detail="Account not found")
+            account = _row_to_dict(row)
+            # Check if token is expired or missing
+            expires_at = account.get("expires_at")
+            now_str = time.strftime("%Y-%m-%dT%H:%M:%S", time.gmtime())
+            if not expires_at or expires_at <= now_str:
+                account = await refresh_access_token_in_db(account_id)
+        else:
+            candidates = await _list_enabled_accounts()
+            if not candidates:
+                raise HTTPException(status_code=503, detail="No enabled account available")
+            account = random.choice(candidates)
+        return await chat_completions(req, account)
+
     # ------------------------------------------------------------------------------
     # Simple Frontend (minimal dev test page; full UI in v2/frontend/index.html)
     # ------------------------------------------------------------------------------

+ 25 - 3
db.py

@@ -94,7 +94,8 @@ class SQLiteBackend(DatabaseBackend):
                 updated_at TEXT,
                 enabled INTEGER DEFAULT 1,
                 error_count INTEGER DEFAULT 0,
-                success_count INTEGER DEFAULT 0
+                success_count INTEGER DEFAULT 0,
+                expires_at TEXT
             )
         """)
         
@@ -103,6 +104,15 @@ class SQLiteBackend(DatabaseBackend):
         await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_accounts_created_at ON accounts (created_at);")
         await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_accounts_success_count ON accounts (success_count);")
 
+        # Add expires_at column if missing (migration)
+        try:
+            async with self._conn.execute("PRAGMA table_info(accounts)") as cursor:
+                cols = {row[1] for row in await cursor.fetchall()}
+                if "expires_at" not in cols:
+                    await self._conn.execute("ALTER TABLE accounts ADD COLUMN expires_at TEXT")
+        except Exception:
+            pass
+
         await self._conn.commit()
         self._initialized = True
 
@@ -160,9 +170,15 @@ class PostgresBackend(DatabaseBackend):
                     updated_at TEXT,
                     enabled INTEGER DEFAULT 1,
                     error_count INTEGER DEFAULT 0,
-                    success_count INTEGER DEFAULT 0
+                    success_count INTEGER DEFAULT 0,
+                    expires_at TEXT
                 )
             """)
+            # Add column if missing (migration)
+            try:
+                await conn.execute("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS expires_at TEXT")
+            except Exception:
+                pass
         self._initialized = True
 
     async def close(self) -> None:
@@ -267,9 +283,15 @@ class MySQLBackend(DatabaseBackend):
                         updated_at TEXT,
                         enabled INT DEFAULT 1,
                         error_count INT DEFAULT 0,
-                        success_count INT DEFAULT 0
+                        success_count INT DEFAULT 0,
+                        expires_at TEXT
                     )
                 """)
+                # Add column if missing (migration)
+                try:
+                    await cur.execute("ALTER TABLE accounts ADD COLUMN expires_at TEXT")
+                except Exception:
+                    pass  # Column already exists
         self._initialized = True
 
     async def close(self) -> None:

+ 30 - 7
frontend/index.html

@@ -289,11 +289,11 @@
 
     <div id="tab-chat" class="tab-content">
       <div class="panel">
-        <h2>Chat 测试(OpenAI 兼容 /v1/chat/completions)</h2>
+        <h2>Chat 测试(/v2/chat/test)</h2>
         <div class="row">
           <div class="field" style="max-width:300px">
             <label>model</label>
-            <input id="model" value="claude-sonnet-4" />
+            <input id="model" value="claude-haiku-4.5" />
           </div>
           <div class="field" style="max-width:180px">
             <label>是否流式</label>
@@ -302,6 +302,12 @@
               <option value="true">true(SSE)</option>
             </select>
           </div>
+          <div class="field" style="max-width:200px">
+            <label>账号</label>
+            <select id="chatAccount">
+              <option value="">随机</option>
+            </select>
+          </div>
           <button class="right" onclick="send()">发送请求</button>
         </div>
         <div class="field">
@@ -621,7 +627,7 @@ class VirtualScroll {
 function renderAccounts(list){
   accountsData = list;
   const root = document.getElementById('accounts');
-  
+
   if (!Array.isArray(list) || list.length === 0) {
     if (virtualScroll) {
       virtualScroll.destroy();
@@ -634,10 +640,10 @@ function renderAccounts(list){
     root.appendChild(empty);
     return;
   }
-  
+
   // Estimate item height (card height ~280px with gap)
   const estimatedItemHeight = 292;
-  
+
   if (virtualScroll) {
     virtualScroll.update(list);
   } else {
@@ -645,6 +651,20 @@ function renderAccounts(list){
   }
 }
 
+function populateChatAccountSelector(accounts) {
+  const sel = document.getElementById('chatAccount');
+  if (!sel) return;
+  const current = sel.value;
+  sel.innerHTML = '<option value="">随机</option>';
+  (accounts || []).forEach(acc => {
+    const opt = document.createElement('option');
+    opt.value = acc.id;
+    opt.textContent = (acc.label || acc.id) + (acc.enabled ? '' : ' (禁用)');
+    sel.appendChild(opt);
+  });
+  if (current) sel.value = current;
+}
+
 async function loadAccounts(){
   try{
     const filter = document.querySelector('input[name="accountFilter"]:checked')?.value || 'all';
@@ -661,6 +681,7 @@ async function loadAccounts(){
     const j = await r.json();
     document.getElementById('accountCount').textContent = `(${j.count})`;
     renderAccounts(j.accounts);
+    populateChatAccountSelector(j.accounts);
   } catch(e){
     alert('加载账户失败:' + e);
   }
@@ -783,6 +804,7 @@ async function refreshAccount(id){
 async function send() {
   const model = document.getElementById('model').value.trim();
   const stream = document.getElementById('stream').value === 'true';
+  const accountId = document.getElementById('chatAccount')?.value || '';
   const out = document.getElementById('out');
   out.textContent = '';
 
@@ -792,9 +814,10 @@ async function send() {
 
   const body = { model, messages, stream };
   const headers = { 'content-type': 'application/json' };
+  const endpoint = accountId ? `/v2/chat/test?account_id=${encodeURIComponent(accountId)}` : '/v2/chat/test';
 
   if (!stream) {
-    const r = await authFetch(api('/v1/chat/completions'), {
+    const r = await authFetch(api(endpoint), {
       method:'POST',
       headers,
       body: JSON.stringify(body)
@@ -803,7 +826,7 @@ async function send() {
     try { out.textContent = JSON.stringify(JSON.parse(text), null, 2); }
     catch { out.textContent = text; }
   } else {
-    const r = await authFetch(api('/v1/chat/completions'), {
+    const r = await authFetch(api(endpoint), {
       method:'POST',
       headers,
       body: JSON.stringify(body)