فهرست منبع

feat: Implement batch tagging functionality for channels

- Added a new endpoint to batch set tags for multiple channels, allowing users to update tags efficiently.
- Introduced a new `BatchSetChannelTag` function in the controller to handle incoming requests and validate parameters.
- Updated the `BatchSetChannelTag` method in the model to manage database transactions and ensure data integrity during tag updates.
- Enhanced the ChannelsTable component in the frontend to support batch tag setting, including UI elements for user interaction.
- Updated localization files to include new translation keys related to batch operations and tag settings.
CalciumIon 1 سال پیش
والد
کامیت
72d6898eb5
7فایلهای تغییر یافته به همراه201 افزوده شده و 24 حذف شده
  1. 28 1
      controller/channel.go
  2. 59 9
      model/ability.go
  3. 41 2
      model/channel.go
  4. 1 1
      router/api-router.go
  5. 64 6
      web/src/components/ChannelsTable.js
  6. 2 2
      web/src/i18n/locales/en copy.json
  7. 6 3
      web/src/i18n/locales/en.json

+ 28 - 1
controller/channel.go

@@ -419,7 +419,8 @@ func EditTagChannels(c *gin.Context) {
 }
 
 type ChannelBatch struct {
-	Ids []int `json:"ids"`
+	Ids []int   `json:"ids"`
+	Tag *string `json:"tag"`
 }
 
 func DeleteChannelBatch(c *gin.Context) {
@@ -570,3 +571,29 @@ func FetchModels(c *gin.Context) {
 		"data":    models,
 	})
 }
+
+func BatchSetChannelTag(c *gin.Context) {
+	channelBatch := ChannelBatch{}
+	err := c.ShouldBindJSON(&channelBatch)
+	if err != nil || len(channelBatch.Ids) == 0 {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "参数错误",
+		})
+		return
+	}
+	err = model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag)
+	if err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": err.Error(),
+		})
+		return
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    len(channelBatch.Ids),
+	})
+	return
+}

+ 59 - 9
model/ability.go

@@ -3,10 +3,11 @@ package model
 import (
 	"errors"
 	"fmt"
-	"github.com/samber/lo"
-	"gorm.io/gorm"
 	"one-api/common"
 	"strings"
+
+	"github.com/samber/lo"
+	"gorm.io/gorm"
 )
 
 type Ability struct {
@@ -173,18 +174,67 @@ func (channel *Channel) DeleteAbilities() error {
 
 // UpdateAbilities updates abilities of this channel.
 // Make sure the channel is completed before calling this function.
-func (channel *Channel) UpdateAbilities() error {
-	// A quick and dirty way to update abilities
+func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
+	isNewTx := false
+	// 如果没有传入事务,创建新的事务
+	if tx == nil {
+		tx = DB.Begin()
+		if tx.Error != nil {
+			return tx.Error
+		}
+		isNewTx = true
+		defer func() {
+			if r := recover(); r != nil {
+				tx.Rollback()
+			}
+		}()
+	}
+
 	// First delete all abilities of this channel
-	err := channel.DeleteAbilities()
+	err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
 	if err != nil {
+		if isNewTx {
+			tx.Rollback()
+		}
 		return err
 	}
+
 	// Then add new abilities
-	err = channel.AddAbilities()
-	if err != nil {
-		return err
+	models_ := strings.Split(channel.Models, ",")
+	groups_ := strings.Split(channel.Group, ",")
+	abilities := make([]Ability, 0, len(models_))
+	for _, model := range models_ {
+		for _, group := range groups_ {
+			ability := Ability{
+				Group:     group,
+				Model:     model,
+				ChannelId: channel.Id,
+				Enabled:   channel.Status == common.ChannelStatusEnabled,
+				Priority:  channel.Priority,
+				Weight:    uint(channel.GetWeight()),
+				Tag:       channel.Tag,
+			}
+			abilities = append(abilities, ability)
+		}
 	}
+
+	if len(abilities) > 0 {
+		for _, chunk := range lo.Chunk(abilities, 50) {
+			err = tx.Create(&chunk).Error
+			if err != nil {
+				if isNewTx {
+					tx.Rollback()
+				}
+				return err
+			}
+		}
+	}
+
+	// 如果是新创建的事务,需要提交
+	if isNewTx {
+		return tx.Commit().Error
+	}
+
 	return nil
 }
 
@@ -246,7 +296,7 @@ func FixAbility() (int, error) {
 		return 0, err
 	}
 	for _, channel := range channels {
-		err := channel.UpdateAbilities()
+		err := channel.UpdateAbilities(nil)
 		if err != nil {
 			common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
 		} else {

+ 41 - 2
model/channel.go

@@ -257,7 +257,7 @@ func (channel *Channel) Update() error {
 		return err
 	}
 	DB.Model(channel).First(channel, "id = ?", channel.Id)
-	err = channel.UpdateAbilities()
+	err = channel.UpdateAbilities(nil)
 	return err
 }
 
@@ -389,7 +389,7 @@ func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *
 		channels, err := GetChannelsByTag(updatedTag, false)
 		if err == nil {
 			for _, channel := range channels {
-				err = channel.UpdateAbilities()
+				err = channel.UpdateAbilities(nil)
 				if err != nil {
 					common.SysError("failed to update abilities: " + err.Error())
 				}
@@ -509,3 +509,42 @@ func (channel *Channel) SetSetting(setting map[string]interface{}) {
 	}
 	channel.Setting = string(settingBytes)
 }
+
+func GetChannelsByIds(ids []int) ([]*Channel, error) {
+	var channels []*Channel
+	err := DB.Where("id in (?)", ids).Find(&channels).Error
+	return channels, err
+}
+
+func BatchSetChannelTag(ids []int, tag *string) error {
+	// 开启事务
+	tx := DB.Begin()
+	if tx.Error != nil {
+		return tx.Error
+	}
+
+	// 更新标签
+	err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error
+	if err != nil {
+		tx.Rollback()
+		return err
+	}
+
+	// update ability status
+	channels, err := GetChannelsByIds(ids)
+	if err != nil {
+		tx.Rollback()
+		return err
+	}
+
+	for _, channel := range channels {
+		err = channel.UpdateAbilities(tx)
+		if err != nil {
+			tx.Rollback()
+			return err
+		}
+	}
+
+	// 提交事务
+	return tx.Commit().Error
+}

+ 1 - 1
router/api-router.go

@@ -99,7 +99,7 @@ func SetApiRouter(router *gin.Engine) {
 			channelRoute.POST("/fix", controller.FixChannelsAbilities)
 			channelRoute.GET("/fetch_models/:id", controller.FetchUpstreamModels)
 			channelRoute.POST("/fetch_models", controller.FetchModels)
-
+			channelRoute.POST("/batch/tag", controller.BatchSetChannelTag)
 		}
 		tokenRoute := apiRouter.Group("/token")
 		tokenRoute.Use(middleware.UserAuth())

+ 64 - 6
web/src/components/ChannelsTable.js

@@ -162,9 +162,15 @@ const ChannelsTable = () => {
         return (
           <div>
             <Space spacing={2}>
-              {text?.split(',').map((item, index) => {
-                return renderGroup(item);
-              })}
+              {text?.split(',')
+                .sort((a, b) => {
+                  if (a === 'default') return -1;
+                  if (b === 'default') return 1;
+                  return a.localeCompare(b);
+                })
+                .map((item, index) => {
+                  return renderGroup(item);
+                })}
             </Space>
           </div>
         );
@@ -507,6 +513,8 @@ const ChannelsTable = () => {
   const [selectedChannels, setSelectedChannels] = useState([]);
   const [showEditPriority, setShowEditPriority] = useState(false);
   const [enableTagMode, setEnableTagMode] = useState(false);
+  const [showBatchSetTag, setShowBatchSetTag] = useState(false);
+  const [batchSetTagValue, setBatchSetTagValue] = useState('');
 
 
   const removeRecord = (record) => {
@@ -968,6 +976,29 @@ const ChannelsTable = () => {
     }
   };
 
+  const batchSetChannelTag = async () => {
+    if (selectedChannels.length === 0) {
+      showError(t('请先选择要设置标签的渠道!'));
+      return;
+    }
+    if (batchSetTagValue === '') {
+      showError(t('标签不能为空!'));
+      return;
+    }
+    let ids = selectedChannels.map(channel => channel.id);
+    const res = await API.post('/api/channel/batch/tag', {
+      ids: ids,
+      tag: batchSetTagValue === '' ? null : batchSetTagValue
+    });
+    if (res.data.success) {
+      showSuccess(t('已为 ${count} 个渠道设置标签!').replace('${count}', res.data.data));
+      await refresh();
+      setShowBatchSetTag(false);
+    } else {
+      showError(res.data.message);
+    }
+  };
+
   return (
     <>
       <EditTagModal
@@ -1115,11 +1146,11 @@ const ChannelsTable = () => {
       </div>
       <div style={{ marginTop: 20 }}>
         <Space>
-          <Typography.Text strong>{t('开启批量删除')}</Typography.Text>
+          <Typography.Text strong>{t('开启批量操作')}</Typography.Text>
           <Switch
-            label={t('开启批量删除')}
+            label={t('开启批量操作')}
             uncheckedText={t('关')}
-            aria-label={t('是否开启批量删除')}
+            aria-label={t('是否开启批量操作')}
             onChange={(v) => {
               setEnableBatchDelete(v);
             }}
@@ -1167,7 +1198,17 @@ const ChannelsTable = () => {
               loadChannels(0, pageSize, idSort, v);
             }}
           />
+          <Button
+        disabled={!enableBatchDelete}
+        theme="light"
+        type="primary"
+        style={{ marginRight: 8 }}
+        onClick={() => setShowBatchSetTag(true)}
+      >
+        {t('批量设置标签')}
+      </Button>
         </Space>
+
       </div>
 
 
@@ -1201,6 +1242,23 @@ const ChannelsTable = () => {
             : null
         }
       />
+      <Modal
+        title={t('批量设置标签')}
+        visible={showBatchSetTag}
+        onOk={batchSetChannelTag}
+        onCancel={() => setShowBatchSetTag(false)}
+        maskClosable={false}
+        centered={true}
+      >
+        <div style={{ marginBottom: 20 }}>
+          <Typography.Text>{t('请输入要设置的标签名称')}</Typography.Text>
+        </div>
+        <Input
+          placeholder={t('请输入标签名称')}
+          value={batchSetTagValue}
+          onChange={(v) => setBatchSetTagValue(v)}
+        />
+      </Modal>
     </>
   );
 };

+ 2 - 2
web/src/i18n/locales/en copy.json

@@ -546,8 +546,8 @@
   "是否用ID排序": "Whether to sort by ID",
   "确定?": "Sure?",
   "确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
-  "开启批量删除": "Enable batch selection",
-  "是否开启批量删除": "Whether to enable batch selection",
+  "开启批量操作": "Enable batch selection",
+  "是否开启批量操作": "Whether to enable batch selection",
   "确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
   "确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
   "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",

+ 6 - 3
web/src/i18n/locales/en.json

@@ -548,8 +548,8 @@
   "是否用ID排序": "Whether to sort by ID",
   "确定?": "Sure?",
   "确定是否要删除禁用通道?": "Are you sure you want to delete the disabled channel?",
-  "开启批量删除": "Enable batch selection",
-  "是否开启批量删除": "Whether to enable batch selection",
+  "开启批量操作": "Enable batch selection",
+  "是否开启批量操作": "Whether to enable batch selection",
   "确定是否要删除所选通道?": "Are you sure you want to delete the selected channels?",
   "确定是否要修复数据库一致性?": "Are you sure you want to repair database consistency?",
   "进行该操作时,可能导致渠道访问错误,请仅在数据库出现问题时使用": "When performing this operation, it may cause channel access errors. Please only use it when there is a problem with the database.",
@@ -1237,5 +1237,8 @@
   "更多": "Expand more",
   "个模型": "models",
   "可用模型": "Available models",
-  "时间范围": "Time range"
+  "时间范围": "Time range",
+  "批量设置标签": "Batch set tag",
+  "请输入要设置的标签名称": "Please enter the tag name to be set",
+  "请输入标签名称": "Please enter the tag name"
 }