Browse Source

feat: 完善标签编辑

CalciumIon 1 year ago
parent
commit
9c4d30602c
4 changed files with 303 additions and 41 deletions
  1. 31 9
      controller/channel.go
  2. 35 3
      model/channel.go
  3. 4 7
      web/src/pages/Channel/EditChannel.js
  4. 233 22
      web/src/pages/Channel/EditTagModal.js

+ 31 - 9
controller/channel.go

@@ -59,18 +59,31 @@ func GetAllChannels(c *gin.Context) {
 	}
 	tags := make(map[string]bool)
 	channelData := make([]*model.Channel, 0, len(channels))
+	tagChannels := make([]*model.Channel, 0)
 	for _, channel := range channels {
 		channelTag := channel.GetTag()
 		if channelTag != "" && !tags[channelTag] {
 			tags[channelTag] = true
-			tagChannels, err := model.GetChannelsByTag(channelTag)
+			tagChannel, err := model.GetChannelsByTag(channelTag)
 			if err == nil {
-				channelData = append(channelData, tagChannels...)
+				tagChannels = append(tagChannels, tagChannel...)
 			}
 		} else {
 			channelData = append(channelData, channel)
 		}
 	}
+	for i, channel := range tagChannels {
+		find := false
+		for _, can := range channelData {
+			if channel.Id == can.Id {
+				find = true
+				break
+			}
+		}
+		if !find {
+			channelData = append(channelData, tagChannels[i])
+		}
+	}
 	c.JSON(http.StatusOK, gin.H{
 		"success": true,
 		"message": "",
@@ -294,11 +307,13 @@ func DeleteDisabledChannel(c *gin.Context) {
 }
 
 type ChannelTag struct {
-	Tag        string  `json:"tag"`
-	NewTag     *string `json:"new_tag"`
-	Priority   *int64  `json:"priority"`
-	Weight     *uint   `json:"weight"`
-	MapMapping *string `json:"map_mapping"`
+	Tag          string  `json:"tag"`
+	NewTag       *string `json:"new_tag"`
+	Priority     *int64  `json:"priority"`
+	Weight       *uint   `json:"weight"`
+	ModelMapping *string `json:"map_mapping"`
+	Models       *string `json:"models"`
+	Groups       *string `json:"groups"`
 }
 
 func DisableTagChannels(c *gin.Context) {
@@ -354,14 +369,21 @@ func EnableTagChannels(c *gin.Context) {
 func EditTagChannels(c *gin.Context) {
 	channelTag := ChannelTag{}
 	err := c.ShouldBindJSON(&channelTag)
-	if err != nil || channelTag.Tag == "" {
+	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,
 			"message": "参数错误",
 		})
 		return
 	}
-	err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.Priority, channelTag.Weight)
+	if channelTag.Tag == "" {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": "tag不能为空",
+		})
+		return
+	}
+	err = model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight)
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{
 			"success": false,

+ 35 - 3
model/channel.go

@@ -329,10 +329,25 @@ func DisableChannelByTag(tag string) error {
 	return err
 }
 
-func EditChannelByTag(tag string, newTag *string, priority *int64, weight *uint) error {
+func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error {
 	updateData := Channel{}
-	if newTag != nil {
+	shouldReCreateAbilities := false
+	updatedTag := tag
+	// 如果 newTag 不为空且不等于 tag,则更新 tag
+	if newTag != nil && *newTag != tag {
 		updateData.Tag = newTag
+		updatedTag = *newTag
+	}
+	if modelMapping != nil && *modelMapping != "" {
+		updateData.ModelMapping = modelMapping
+	}
+	if models != nil && *models != "" {
+		shouldReCreateAbilities = true
+		updateData.Models = *models
+	}
+	if group != nil && *group != "" {
+		shouldReCreateAbilities = true
+		updateData.Group = *group
 	}
 	if priority != nil {
 		updateData.Priority = priority
@@ -340,11 +355,28 @@ func EditChannelByTag(tag string, newTag *string, priority *int64, weight *uint)
 	if weight != nil {
 		updateData.Weight = weight
 	}
+
 	err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error
 	if err != nil {
 		return err
 	}
-	return UpdateAbilityByTag(tag, newTag, priority, weight)
+	if shouldReCreateAbilities {
+		channels, err := GetChannelsByTag(updatedTag)
+		if err == nil {
+			for _, channel := range channels {
+				err = channel.UpdateAbilities()
+				if err != nil {
+					common.SysError("failed to update abilities: " + err.Error())
+				}
+			}
+		}
+	} else {
+		err := UpdateAbilityByTag(tag, newTag, priority, weight)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 func UpdateChannelUsedQuota(id int, quota int) {

+ 4 - 7
web/src/pages/Channel/EditChannel.js

@@ -28,9 +28,7 @@ import { getChannelModels, loadChannelModels } from '../../components/utils.js';
 import axios from 'axios';
 
 const MODEL_MAPPING_EXAMPLE = {
-  'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
-  'gpt-4-0314': 'gpt-4',
-  'gpt-4-32k-0314': 'gpt-4-32k'
+  'gpt-3.5-turbo': 'gpt-3.5-turbo-0125'
 };
 
 const STATUS_CODE_MAPPING_EXAMPLE = {
@@ -253,7 +251,7 @@ const EditChannel = (props) => {
       setBasicModels(
         res.data.data
           .filter((model) => {
-            return model.id.startsWith('gpt-3') || model.id.startsWith('text-');
+            return model.id.startsWith('gpt-') || model.id.startsWith('text-');
           })
           .map((model) => model.id)
       );
@@ -282,7 +280,7 @@ const EditChannel = (props) => {
   useEffect(() => {
     let localModelOptions = [...originModelOptions];
     inputs.models.forEach((model) => {
-      if (!localModelOptions.find((option) => option.key === model)) {
+      if (!localModelOptions.find((option) => option.label === model)) {
         localModelOptions.push({
           label: model,
           value: model
@@ -296,8 +294,7 @@ const EditChannel = (props) => {
     fetchModels().then();
     fetchGroups().then();
     if (isEdit) {
-      loadChannel().then(() => {
-      });
+      loadChannel().then(() => {});
     } else {
       setInputs(originInputs);
       let localModels = getChannelModels(inputs.type);

+ 233 - 22
web/src/pages/Channel/EditTagModal.js

@@ -1,46 +1,138 @@
 import React, { useState, useEffect } from 'react';
 import { API, showError, showSuccess } from '../../helpers';
-import { SideSheet, Space, Button, Input, Typography, Spin, Modal } from '@douyinfe/semi-ui';
+import { SideSheet, Space, Button, Input, Typography, Spin, Modal, Select, Banner, TextArea } from '@douyinfe/semi-ui';
 import TextInput from '../../components/TextInput.js';
+import { getChannelModels } from '../../components/utils.js';
+
+const MODEL_MAPPING_EXAMPLE = {
+  'gpt-3.5-turbo': 'gpt-3.5-turbo-0125'
+};
 
 const EditTagModal = (props) => {
   const { visible, tag, handleClose, refresh } = props;
   const [loading, setLoading] = useState(false);
+  const [originModelOptions, setOriginModelOptions] = useState([]);
+  const [modelOptions, setModelOptions] = useState([]);
+  const [groupOptions, setGroupOptions] = useState([]);
+  const [basicModels, setBasicModels] = useState([]);
+  const [fullModels, setFullModels] = useState([]);
   const originInputs = {
     tag: '',
     new_tag: null,
     model_mapping: null,
+    groups: [],
+    models: [],
   }
   const [inputs, setInputs] = useState(originInputs);
 
+  const handleInputChange = (name, value) => {
+    setInputs((inputs) => ({ ...inputs, [name]: value }));
+    if (name === 'type') {
+      let localModels = [];
+      switch (value) {
+        case 2:
+          localModels = [
+            'mj_imagine',
+            'mj_variation',
+            'mj_reroll',
+            'mj_blend',
+            'mj_upscale',
+            'mj_describe',
+            'mj_uploads'
+          ];
+          break;
+        case 5:
+          localModels = [
+            'swap_face',
+            'mj_imagine',
+            'mj_variation',
+            'mj_reroll',
+            'mj_blend',
+            'mj_upscale',
+            'mj_describe',
+            'mj_zoom',
+            'mj_shorten',
+            'mj_modal',
+            'mj_inpaint',
+            'mj_custom_zoom',
+            'mj_high_variation',
+            'mj_low_variation',
+            'mj_pan',
+            'mj_uploads'
+          ];
+          break;
+        case 36:
+          localModels = [
+            'suno_music',
+            'suno_lyrics'
+          ];
+          break;
+        default:
+          localModels = getChannelModels(value);
+          break;
+      }
+      if (inputs.models.length === 0) {
+        setInputs((inputs) => ({ ...inputs, models: localModels }));
+      }
+      setBasicModels(localModels);
+    }
+  };
+
+  const fetchModels = async () => {
+    try {
+      let res = await API.get(`/api/channel/models`);
+      let localModelOptions = res.data.data.map((model) => ({
+        label: model.id,
+        value: model.id
+      }));
+      setOriginModelOptions(localModelOptions);
+      setFullModels(res.data.data.map((model) => model.id));
+      setBasicModels(
+        res.data.data
+          .filter((model) => {
+            return model.id.startsWith('gpt-') || model.id.startsWith('text-');
+          })
+          .map((model) => model.id)
+      );
+    } catch (error) {
+      showError(error.message);
+    }
+  };
+
+  const fetchGroups = async () => {
+    try {
+      let res = await API.get(`/api/group/`);
+      if (res === undefined) {
+        return;
+      }
+      setGroupOptions(
+        res.data.data.map((group) => ({
+          label: group,
+          value: group
+        }))
+      );
+    } catch (error) {
+      showError(error.message);
+    }
+  };
+
 
   const handleSave = async () => {
     setLoading(true);
     let data = {
       tag: tag,
     }
-    if (inputs.newTag === tag) {
-      setLoading(false);
-      return;
-    }
     if (inputs.model_mapping !== null) {
-      data.model_mapping = inputs.model
+      data.model_mapping = inputs.model_mapping
     }
-    data.newTag = inputs.newTag;
-    if (data.newTag === '') {
-      Modal.confirm({
-        title: '解散标签',
-        content: '确定要解散标签吗?',
-        onCancel: () => {
-          setLoading(false);
-        },
-        onOk: async () => {
-          await submit(data);
-        }
-      });
-    } else {
-      await submit(data);
+    if (inputs.groups.length > 0) {
+      data.groups = inputs.groups.join(',');
     }
+    if (inputs.models.length > 0) {
+      data.models = inputs.models.join(',');
+    }
+    data.newTag = inputs.newTag;
+    await submit(data);
     setLoading(false);
   };
 
@@ -57,12 +149,27 @@ const EditTagModal = (props) => {
     }
   }
 
+  useEffect(() => {
+    let localModelOptions = [...originModelOptions];
+    inputs.models.forEach((model) => {
+      if (!localModelOptions.find((option) => option.label === model)) {
+        localModelOptions.push({
+          label: model,
+          value: model
+        });
+      }
+    });
+    setModelOptions(localModelOptions);
+  }, [originModelOptions, inputs.models]);
+
   useEffect(() => {
     setInputs({
       ...originInputs,
       tag: tag,
-      newTag: tag,
+      new_tag: tag,
     })
+    fetchModels().then();
+    fetchGroups().then();
   }, [visible]);
 
   return (
@@ -79,14 +186,118 @@ const EditTagModal = (props) => {
         </div>
       }
     >
+      <div style={{ marginTop: 10 }}>
+        <Banner
+          type={'warning'}
+          description={
+            <>
+              所有编辑均为覆盖操作,留空则不更改
+            </>
+          }
+        ></Banner>
+      </div>
       <Spin spinning={loading}>
         <TextInput
-          label="新标签(留空则解散标签,不会删除标签下的渠道)"
+          label="新标签,留空则不更改"
           name="newTag"
           value={inputs.new_tag}
           onChange={(value) => setInputs({ ...inputs, new_tag: value })}
           placeholder="请输入新标签"
         />
+        <div style={{ marginTop: 10 }}>
+          <Typography.Text strong>模型,留空则不更改:</Typography.Text>
+        </div>
+        <Select
+          placeholder={'请选择该渠道所支持的模型,留空则不更改'}
+          name="models"
+          required
+          multiple
+          selection
+          onChange={(value) => {
+            handleInputChange('models', value);
+          }}
+          value={inputs.models}
+          autoComplete="new-password"
+          optionList={modelOptions}
+        />
+        <div style={{ marginTop: 10 }}>
+          <Typography.Text strong>分组,留空则不更改:</Typography.Text>
+        </div>
+        <Select
+          placeholder={'请选择可以使用该渠道的分组,留空则不更改'}
+          name="groups"
+          required
+          multiple
+          selection
+          allowAdditions
+          additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
+          onChange={(value) => {
+            handleInputChange('groups', value);
+          }}
+          value={inputs.groups}
+          autoComplete="new-password"
+          optionList={groupOptions}
+        />
+        <div style={{ marginTop: 10 }}>
+          <Typography.Text strong>模型重定向:</Typography.Text>
+        </div>
+        <TextArea
+          placeholder={`此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,留空则不更改`}
+          name="model_mapping"
+          onChange={(value) => {
+            handleInputChange('model_mapping', value);
+          }}
+          autosize
+          value={inputs.model_mapping}
+          autoComplete="new-password"
+        />
+        <Space>
+          <Typography.Text
+            style={{
+              color: 'rgba(var(--semi-blue-5), 1)',
+              userSelect: 'none',
+              cursor: 'pointer'
+            }}
+            onClick={() => {
+              handleInputChange(
+                'model_mapping',
+                JSON.stringify(MODEL_MAPPING_EXAMPLE, null, 2)
+              );
+            }}
+          >
+            填入模板
+          </Typography.Text>
+          <Typography.Text
+            style={{
+              color: 'rgba(var(--semi-blue-5), 1)',
+              userSelect: 'none',
+              cursor: 'pointer'
+            }}
+            onClick={() => {
+              handleInputChange(
+                'model_mapping',
+                JSON.stringify({}, null, 2)
+              );
+            }}
+          >
+            清空重定向
+          </Typography.Text>
+          <Typography.Text
+            style={{
+              color: 'rgba(var(--semi-blue-5), 1)',
+              userSelect: 'none',
+              cursor: 'pointer'
+            }}
+            onClick={() => {
+              handleInputChange(
+                'model_mapping',
+                ""
+              );
+            }}
+          >
+            不更改
+          </Typography.Text>
+        </Space>
       </Spin>
     </SideSheet>
   );