Browse Source

feat: supprt channel priority now & record channel id in log (#484)

* feat: 支持设置渠道优先级 & 日志中显示使用的渠道ID

* fix: 设置渠道优先级未更新 ability

* chore: update implementation

---------

Co-authored-by: Xiangyuan Liu <[email protected]>
Co-authored-by: JustSong <[email protected]>
Co-authored-by: JustSong <[email protected]>
Xyfacai 2 years ago
parent
commit
ecf8a6d875

+ 6 - 3
controller/log.go

@@ -19,7 +19,8 @@ func GetAllLogs(c *gin.Context) {
 	username := c.Query("username")
 	tokenName := c.Query("token_name")
 	modelName := c.Query("model_name")
-	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
@@ -106,7 +107,8 @@ func GetLogsStat(c *gin.Context) {
 	tokenName := c.Query("token_name")
 	username := c.Query("username")
 	modelName := c.Query("model_name")
-	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
@@ -126,7 +128,8 @@ func GetLogsSelfStat(c *gin.Context) {
 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 	tokenName := c.Query("token_name")
 	modelName := c.Query("model_name")
-	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
+	channel, _ := strconv.Atoi(c.Query("channel"))
+	quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
 	//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,

+ 2 - 1
controller/relay-audio.go

@@ -18,6 +18,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
 	group := c.GetString("group")
 
@@ -107,7 +108,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, 0, 0, audioModel, tokenName, quota, logContent)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)

+ 2 - 1
controller/relay-image.go

@@ -19,6 +19,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 
 	tokenId := c.GetInt("token_id")
 	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
 	userId := c.GetInt("id")
 	consumeQuota := c.GetBool("consume_quota")
 	group := c.GetString("group")
@@ -138,7 +139,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
 			if quota != 0 {
 				tokenName := c.GetString("token_name")
 				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-				model.RecordConsumeLog(ctx, userId, 0, 0, imageModel, tokenName, quota, logContent)
+				model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
 				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 				channelId := c.GetInt("channel_id")
 				model.UpdateChannelUsedQuota(channelId, quota)

+ 2 - 2
controller/relay-text.go

@@ -38,6 +38,7 @@ func init() {
 
 func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
+	channelId := c.GetInt("channel_id")
 	tokenId := c.GetInt("token_id")
 	userId := c.GetInt("id")
 	consumeQuota := c.GetBool("consume_quota")
@@ -364,7 +365,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 
 	var textResponse TextResponse
 	tokenName := c.GetString("token_name")
-	channelId := c.GetInt("channel_id")
 
 	defer func(ctx context.Context) {
 		// c.Writer.Flush()
@@ -397,7 +397,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 				}
 				if quota != 0 {
 					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-					model.RecordConsumeLog(ctx, userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
+					model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 					model.UpdateChannelUsedQuota(channelId, quota)
 				}

+ 4 - 2
model/ability.go

@@ -10,15 +10,16 @@ type Ability struct {
 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 	Enabled   bool   `json:"enabled"`
+	Priority  int64  `json:"priority" gorm:"bigint;default:0"`
 }
 
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	ability := Ability{}
 	var err error = nil
 	if common.UsingSQLite {
-		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
+		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error
 	} else {
-		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
+		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error
 	}
 	if err != nil {
 		return nil, err
@@ -40,6 +41,7 @@ func (channel *Channel) AddAbilities() error {
 				Model:     model,
 				ChannelId: channel.Id,
 				Enabled:   channel.Status == common.ChannelStatusEnabled,
+				Priority:  channel.Priority,
 			}
 			abilities = append(abilities, ability)
 		}

+ 17 - 0
model/cache.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"math/rand"
 	"one-api/common"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -159,6 +160,17 @@ func InitChannelCache() {
 			}
 		}
 	}
+
+	// sort by priority
+	for group, model2channels := range newGroup2model2channels {
+		for model, channels := range model2channels {
+			sort.Slice(channels, func(i, j int) bool {
+				return channels[i].Priority > channels[j].Priority
+			})
+			newGroup2model2channels[group][model] = channels
+		}
+	}
+
 	channelSyncLock.Lock()
 	group2model2channels = newGroup2model2channels
 	channelSyncLock.Unlock()
@@ -183,6 +195,11 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 	if len(channels) == 0 {
 		return nil, errors.New("channel not found")
 	}
+	// choose by priority
+	firstChannel := channels[0]
+	if firstChannel.Priority > 0 {
+		return firstChannel, nil
+	}
 	idx := rand.Intn(len(channels))
 	return channels[idx], nil
 }

+ 1 - 0
model/channel.go

@@ -23,6 +23,7 @@ type Channel struct {
 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
+	Priority           int64   `json:"priority" gorm:"bigint;default:0"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

+ 13 - 4
model/log.go

@@ -19,6 +19,7 @@ type Log struct {
 	Quota            int    `json:"quota" gorm:"default:0"`
 	PromptTokens     int    `json:"prompt_tokens" gorm:"default:0"`
 	CompletionTokens int    `json:"completion_tokens" gorm:"default:0"`
+	Channel          int    `json:"channel" gorm:"default:0"`
 }
 
 const (
@@ -46,8 +47,9 @@ func RecordLog(userId int, logType int, content string) {
 	}
 }
 
-func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
-	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, promptTokens, completionTokens, modelName, tokenName, quota, content))
+
+func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
+	common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
 	if !common.LogConsumeEnabled {
 		return
 	}
@@ -62,6 +64,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet
 		TokenName:        tokenName,
 		ModelName:        modelName,
 		Quota:            quota,
+		Channel:          channelId,
 	}
 	err := DB.Create(log).Error
 	if err != nil {
@@ -69,7 +72,7 @@ func RecordConsumeLog(ctx context.Context, userId int, promptTokens int, complet
 	}
 }
 
-func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
+func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
 	var tx *gorm.DB
 	if logType == LogTypeUnknown {
 		tx = DB
@@ -91,6 +94,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
 	if endTimestamp != 0 {
 		tx = tx.Where("created_at <= ?", endTimestamp)
 	}
+	if channel != 0 {
+		tx = tx.Where("channel = ?", channel)
+	}
 	err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
 	return logs, err
 }
@@ -128,7 +134,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
 	return logs, err
 }
 
-func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
+func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
 	tx := DB.Table("logs").Select("sum(quota)")
 	if username != "" {
 		tx = tx.Where("username = ?", username)
@@ -145,6 +151,9 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
 	if modelName != "" {
 		tx = tx.Where("model_name = ?", modelName)
 	}
+	if channel != 0 {
+		tx = tx.Where("channel = ?", channel)
+	}
 	tx.Where("type = ?", LogTypeConsume).Scan(&quota)
 	return quota
 }

+ 35 - 4
web/src/components/ChannelsTable.js

@@ -1,5 +1,5 @@
 import React, { useEffect, useState } from 'react';
-import { Button, Form, Label, Pagination, Popup, Table } from 'semantic-ui-react';
+import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
 import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
 
@@ -24,7 +24,7 @@ function renderType(type) {
     }
     type2label[0] = { value: 0, text: '未知类型', color: 'grey' };
   }
-  return <Label basic color={type2label[type].color}>{type2label[type].text}</Label>;
+  return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>;
 }
 
 function renderBalance(type, balance) {
@@ -96,7 +96,7 @@ const ChannelsTable = () => {
       });
   }, []);
 
-  const manageChannel = async (id, action, idx) => {
+  const manageChannel = async (id, action, idx, priority) => {
     let data = { id };
     let res;
     switch (action) {
@@ -111,6 +111,13 @@ const ChannelsTable = () => {
         data.status = 2;
         res = await API.put('/api/channel/', data);
         break;
+      case 'priority':
+        if (priority === '') {
+          return;
+        }
+        data.priority = parseInt(priority);
+        res = await API.put('/api/channel/', data);
+        break;
     }
     const { success, message } = res.data;
     if (success) {
@@ -335,6 +342,14 @@ const ChannelsTable = () => {
             >
               余额
             </Table.HeaderCell>
+            <Table.HeaderCell
+                style={{ cursor: 'pointer' }}
+                onClick={() => {
+                  sortChannel('priority');
+                }}
+            >
+              优先级
+            </Table.HeaderCell>
             <Table.HeaderCell>操作</Table.HeaderCell>
           </Table.Row>
         </Table.Header>
@@ -373,6 +388,22 @@ const ChannelsTable = () => {
                       basic
                     />
                   </Table.Cell>
+                  <Table.Cell>
+                    <Popup
+                        trigger={<Input type="number"  defaultValue={channel.priority} onBlur={(event) => {
+                          manageChannel(
+                              channel.id,
+                              'priority',
+                              idx,
+                              event.target.value,
+                          );
+                        }}>
+                          <input style={{maxWidth:'60px'}} />
+                        </Input>}
+                        content='渠道选择优先级,越高越优先'
+                        basic
+                    />
+                  </Table.Cell>
                   <Table.Cell>
                     <div>
                       <Button
@@ -441,7 +472,7 @@ const ChannelsTable = () => {
 
         <Table.Footer>
           <Table.Row>
-            <Table.HeaderCell colSpan='8'>
+            <Table.HeaderCell colSpan='9'>
               <Button size='small' as={Link} to='/channel/add' loading={loading}>
                 添加新的渠道
               </Button>

+ 40 - 17
web/src/components/LogsTable.js

@@ -56,9 +56,10 @@ const LogsTable = () => {
     token_name: '',
     model_name: '',
     start_timestamp: timestamp2string(0),
-    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600)
+    end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
+    channel: ''
   });
-  const { username, token_name, model_name, start_timestamp, end_timestamp } = inputs;
+  const { username, token_name, model_name, start_timestamp, end_timestamp, channel } = inputs;
 
   const [stat, setStat] = useState({
     quota: 0,
@@ -84,7 +85,7 @@ const LogsTable = () => {
   const getLogStat = async () => {
     let localStartTimestamp = Date.parse(start_timestamp) / 1000;
     let localEndTimestamp = Date.parse(end_timestamp) / 1000;
-    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`);
+    let res = await API.get(`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`);
     const { success, message, data } = res.data;
     if (success) {
       setStat(data);
@@ -109,7 +110,7 @@ const LogsTable = () => {
     let localStartTimestamp = Date.parse(start_timestamp) / 1000;
     let localEndTimestamp = Date.parse(end_timestamp) / 1000;
     if (isAdminUser) {
-      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
+      url = `/api/log/?p=${startIdx}&type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
     } else {
       url = `/api/log/self/?p=${startIdx}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
     }
@@ -205,16 +206,9 @@ const LogsTable = () => {
         </Header>
         <Form>
           <Form.Group>
-            {
-              isAdminUser && (
-                <Form.Input fluid label={'用户名称'} width={2} value={username}
-                            placeholder={'可选值'} name='username'
-                            onChange={handleInputChange} />
-              )
-            }
-            <Form.Input fluid label={'令牌名称'} width={isAdminUser ? 2 : 3} value={token_name}
+            <Form.Input fluid label={'令牌名称'} width={3} value={token_name}
                         placeholder={'可选值'} name='token_name' onChange={handleInputChange} />
-            <Form.Input fluid label='模型名称' width={isAdminUser ? 2 : 3} value={model_name} placeholder='可选值'
+            <Form.Input fluid label='模型名称' width={3} value={model_name} placeholder='可选值'
                         name='model_name'
                         onChange={handleInputChange} />
             <Form.Input fluid label='起始时间' width={4} value={start_timestamp} type='datetime-local'
@@ -225,6 +219,19 @@ const LogsTable = () => {
                         onChange={handleInputChange} />
             <Form.Button fluid label='操作' width={2} onClick={refresh}>查询</Form.Button>
           </Form.Group>
+          {
+            isAdminUser && <>
+              <Form.Group>
+                <Form.Input fluid label={'渠道 ID'} width={3} value={channel}
+                            placeholder='可选值' name='channel'
+                            onChange={handleInputChange} />
+                <Form.Input fluid label={'用户名称'} width={3} value={username}
+                            placeholder={'可选值'} name='username'
+                            onChange={handleInputChange} />
+
+              </Form.Group>
+            </>
+          }
         </Form>
         <Table basic compact size='small'>
           <Table.Header>
@@ -238,6 +245,17 @@ const LogsTable = () => {
               >
                 时间
               </Table.HeaderCell>
+              {
+                isAdminUser && <Table.HeaderCell
+                  style={{ cursor: 'pointer' }}
+                  onClick={() => {
+                    sortLog('channel');
+                  }}
+                  width={1}
+                >
+                  渠道
+                </Table.HeaderCell>
+              }
               {
                 isAdminUser && <Table.HeaderCell
                   style={{ cursor: 'pointer' }}
@@ -299,16 +317,16 @@ const LogsTable = () => {
                 onClick={() => {
                   sortLog('quota');
                 }}
-                width={2}
+                width={1}
               >
-                消耗额度
+                额度
               </Table.HeaderCell>
               <Table.HeaderCell
                 style={{ cursor: 'pointer' }}
                 onClick={() => {
                   sortLog('content');
                 }}
-                width={isAdminUser ? 4 : 5}
+                width={isAdminUser ? 4 : 6}
               >
                 详情
               </Table.HeaderCell>
@@ -326,6 +344,11 @@ const LogsTable = () => {
                 return (
                   <Table.Row key={log.id}>
                     <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
+                    {
+                      isAdminUser && (
+                        <Table.Cell>{log.channel ? <Label basic>{log.channel}</Label> : ''}</Table.Cell>
+                      )
+                    }
                     {
                       isAdminUser && (
                         <Table.Cell>{log.username ? <Label>{log.username}</Label> : ''}</Table.Cell>
@@ -345,7 +368,7 @@ const LogsTable = () => {
 
           <Table.Footer>
             <Table.Row>
-              <Table.HeaderCell colSpan={'9'}>
+              <Table.HeaderCell colSpan={'10'}>
                 <Select
                   placeholder='选择明细分类'
                   options={LOG_OPTIONS}