浏览代码

feat: 渠道新可选是否自动禁用功能

CaIon 2 年之前
父节点
当前提交
80271b33ba
共有 6 个文件被更改,包括 45 次插入4 次删除
  1. 6 1
      controller/channel-test.go
  2. 1 1
      controller/relay-utils.go
  3. 2 1
      controller/relay.go
  4. 1 0
      middleware/distributor.go
  5. 1 0
      model/channel.go
  6. 34 1
      web/src/pages/Channel/EditChannel.js

+ 6 - 1
controller/channel-test.go

@@ -183,7 +183,12 @@ func testAllChannels(notify bool) error {
 				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
-			if shouldDisableChannel(openaiErr, -1) {
+			ban := true
+			// parse *int to bool
+			if channel.AutoBan != nil && *channel.AutoBan == 0 {
+				ban = false
+			}
+			if shouldDisableChannel(openaiErr, -1) && ban {
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
 			channel.UpdateResponseTime(milliseconds)

+ 1 - 1
controller/relay-utils.go

@@ -128,7 +128,7 @@ func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
 	if err == nil {
 		return false
 	}
-	if statusCode == http.StatusUnauthorized {
+	if statusCode == http.StatusUnauthorized || statusCode == http.StatusTooManyRequests {
 		return true
 	}
 	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {

+ 2 - 1
controller/relay.go

@@ -236,9 +236,10 @@ func Relay(c *gin.Context) {
 			})
 		}
 		channelId := c.GetInt("channel_id")
+		autoBan := c.GetBool("auto_ban")
 		common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 		// https://platform.openai.com/docs/guides/error-codes/api-errors
-		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
+		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) && autoBan {
 			channelId := c.GetInt("channel_id")
 			channelName := c.GetString("channel_name")
 			disableChannel(channelId, channelName, err.Message)

+ 1 - 0
middleware/distributor.go

@@ -87,6 +87,7 @@ func Distribute() func(c *gin.Context) {
 		c.Set("channel", channel.Type)
 		c.Set("channel_id", channel.Id)
 		c.Set("channel_name", channel.Name)
+		c.Set("auto_ban", channel.AutoBan)
 		c.Set("model_mapping", channel.GetModelMapping())
 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 		c.Set("base_url", channel.GetBaseURL())

+ 1 - 0
model/channel.go

@@ -25,6 +25,7 @@ type Channel struct {
 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"`
 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"`
+	AutoBan            *int    `json:"auto_ban" gorm:"default:1"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {

+ 34 - 1
web/src/pages/Channel/EditChannel.js

@@ -43,9 +43,12 @@ const EditChannel = () => {
     other: '',
     model_mapping: '',
     models: [],
+    auto_ban: 1,
     groups: ['default']
   };
   const [batch, setBatch] = useState(false);
+  const [autoBan, setAutoBan] = useState(true);
+  // const [autoBan, setAutoBan] = useState(true);
   const [inputs, setInputs] = useState(originInputs);
   const [originModelOptions, setOriginModelOptions] = useState([]);
   const [modelOptions, setModelOptions] = useState([]);
@@ -82,6 +85,7 @@ const EditChannel = () => {
       }
       setInputs((inputs) => ({ ...inputs, models: localModels }));
     }
+    //setAutoBan
   };
 
   const loadChannel = async () => {
@@ -102,6 +106,12 @@ const EditChannel = () => {
         data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
       }
       setInputs(data);
+      if (data.auto_ban === 0) {
+        setAutoBan(false);
+      } else {
+        setAutoBan(true);
+      }
+      // console.log(data);
     } else {
       showError(message);
     }
@@ -161,6 +171,11 @@ const EditChannel = () => {
     fetchGroups().then();
   }, []);
 
+  useEffect(() => {
+    setInputs((inputs) => ({ ...inputs, auto_ban: autoBan ? 1 : 0 }));
+    console.log(autoBan);
+  }, [autoBan]);
+
   const submit = async () => {
     if (!isEdit && (inputs.name === '' || inputs.key === '')) {
       showInfo('请填写渠道名称和渠道密钥!');
@@ -185,6 +200,11 @@ const EditChannel = () => {
       localInputs.other = 'v2.1';
     }
     let res;
+    if (!Array.isArray(localInputs.models)) {
+        showError('提交失败,请勿重复提交!');
+        handleCancel();
+        return;
+    }
     localInputs.models = localInputs.models.join(',');
     localInputs.group = localInputs.groups.join(',');
     if (isEdit) {
@@ -423,7 +443,20 @@ const EditChannel = () => {
                 placeholder='请输入组织org-xxx'
                 onChange={handleInputChange}
                 value={inputs.openai_organization}
-                autoComplete='new-password'
+            />
+          </Form.Field>
+          <Form.Field>
+            <Form.Checkbox
+                label='是否自动禁用(仅当自动禁用开启时有效),关闭后不会自动禁用该渠道'
+                name='auto_ban'
+                checked={autoBan}
+                onChange={
+                    () => {
+                        setAutoBan(!autoBan);
+
+                    }
+                }
+                // onChange={handleInputChange}
             />
           </Form.Field>
           {