Просмотр исходного кода

Merge branch 'tbphp-tbphp_model_request_rate_limit_for_group'

creamlike1024 7 месяцев назад
Родитель
Сommit
9de24668d8

+ 9 - 0
controller/option.go

@@ -110,6 +110,15 @@ func UpdateOption(c *gin.Context) {
 			})
 			return
 		}
+	case "ModelRequestRateLimitGroup":
+		err = setting.CheckModelRequestRateLimitGroup(option.Value)
+		if err != nil {
+			c.JSON(http.StatusOK, gin.H{
+				"success": false,
+				"message": err.Error(),
+			})
+			return
+		}
 
 	}
 	err = model.UpdateOption(option.Key, option.Value)

+ 14 - 0
middleware/model-rate-limit.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/common/limiter"
+	"one-api/constant"
 	"one-api/setting"
 	"strconv"
 	"time"
@@ -175,6 +176,19 @@ func ModelRequestRateLimit() func(c *gin.Context) {
 		totalMaxCount := setting.ModelRequestRateLimitCount
 		successMaxCount := setting.ModelRequestRateLimitSuccessCount
 
+		// 获取分组
+		group := c.GetString("token_group")
+		if group == "" {
+			group = c.GetString(constant.ContextKeyUserGroup)
+		}
+
+		//获取分组的限流配置
+		groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group)
+		if found {
+			totalMaxCount = groupTotalCount
+			successMaxCount = groupSuccessCount
+		}
+
 		// 根据存储类型选择并执行限流处理器
 		if common.RedisEnabled {
 			redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c)

+ 3 - 0
model/option.go

@@ -92,6 +92,7 @@ func InitOptionMap() {
 	common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
 	common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
 	common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
+	common.OptionMap["ModelRequestRateLimitGroup"] = setting.ModelRequestRateLimitGroup2JSONString()
 	common.OptionMap["ModelRatio"] = operation_setting.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = operation_setting.ModelPrice2JSONString()
 	common.OptionMap["CacheRatio"] = operation_setting.CacheRatio2JSONString()
@@ -338,6 +339,8 @@ func updateOptionMap(key string, value string) (err error) {
 		setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
 	case "ModelRequestRateLimitSuccessCount":
 		setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
+	case "ModelRequestRateLimitGroup":
+		err = setting.UpdateModelRequestRateLimitGroupByJSONString(value)
 	case "RetryTimes":
 		common.RetryTimes, _ = strconv.Atoi(value)
 	case "DataExportInterval":

+ 58 - 0
setting/rate_limit.go

@@ -1,6 +1,64 @@
 package setting
 
+import (
+	"encoding/json"
+	"fmt"
+	"one-api/common"
+	"sync"
+)
+
 var ModelRequestRateLimitEnabled = false
 var ModelRequestRateLimitDurationMinutes = 1
 var ModelRequestRateLimitCount = 0
 var ModelRequestRateLimitSuccessCount = 1000
+var ModelRequestRateLimitGroup = map[string][2]int{}
+var ModelRequestRateLimitMutex sync.RWMutex
+
+func ModelRequestRateLimitGroup2JSONString() string {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
+
+	jsonBytes, err := json.Marshal(ModelRequestRateLimitGroup)
+	if err != nil {
+		common.SysError("error marshalling model ratio: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+func UpdateModelRequestRateLimitGroupByJSONString(jsonStr string) error {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
+
+	ModelRequestRateLimitGroup = make(map[string][2]int)
+	return json.Unmarshal([]byte(jsonStr), &ModelRequestRateLimitGroup)
+}
+
+func GetGroupRateLimit(group string) (totalCount, successCount int, found bool) {
+	ModelRequestRateLimitMutex.RLock()
+	defer ModelRequestRateLimitMutex.RUnlock()
+
+	if ModelRequestRateLimitGroup == nil {
+		return 0, 0, false
+	}
+
+	limits, found := ModelRequestRateLimitGroup[group]
+	if !found {
+		return 0, 0, false
+	}
+	return limits[0], limits[1], true
+}
+
+func CheckModelRequestRateLimitGroup(jsonStr string) error {
+	checkModelRequestRateLimitGroup := make(map[string][2]int)
+	err := json.Unmarshal([]byte(jsonStr), &checkModelRequestRateLimitGroup)
+	if err != nil {
+		return err
+	}
+	for group, limits := range checkModelRequestRateLimitGroup {
+		if limits[0] < 0 || limits[1] < 1 {
+			return fmt.Errorf("group %s has negative rate limit values: [%d, %d]", group, limits[0], limits[1])
+		}
+	}
+
+	return nil
+}

+ 9 - 4
web/src/components/RateLimitSetting.js

@@ -13,6 +13,7 @@ const RateLimitSetting = () => {
     ModelRequestRateLimitCount: 0,
     ModelRequestRateLimitSuccessCount: 1000,
     ModelRequestRateLimitDurationMinutes: 1,
+    ModelRequestRateLimitGroup: '',
   });
 
   let [loading, setLoading] = useState(false);
@@ -23,10 +24,14 @@ const RateLimitSetting = () => {
     if (success) {
       let newInputs = {};
       data.forEach((item) => {
-        if (item.key.endsWith('Enabled')) {
-          newInputs[item.key] = item.value === 'true' ? true : false;
-        } else {
-          newInputs[item.key] = item.value;
+      if (item.key === 'ModelRequestRateLimitGroup') {
+        item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+      }
+
+      if (item.key.endsWith('Enabled')) {
+        newInputs[item.key] = item.value === 'true' ? true : false;
+      } else {
+        newInputs[item.key] = item.value;
         }
       });
 

+ 44 - 0
web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js

@@ -6,6 +6,7 @@ import {
   showError,
   showSuccess,
   showWarning,
+  verifyJSON,
 } from '../../../helpers';
 import { useTranslation } from 'react-i18next';
 
@@ -18,6 +19,7 @@ export default function RequestRateLimit(props) {
     ModelRequestRateLimitCount: -1,
     ModelRequestRateLimitSuccessCount: 1000,
     ModelRequestRateLimitDurationMinutes: 1,
+    ModelRequestRateLimitGroup: '',
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -46,6 +48,13 @@ export default function RequestRateLimit(props) {
           if (res.includes(undefined))
             return showError(t('部分保存失败,请重试'));
         }
+
+      for (let i = 0; i < res.length; i++) {
+        if (!res[i].data.success) {
+          return showError(res[i].data.message);
+        }
+      }
+
         showSuccess(t('保存成功'));
         props.refresh();
       })
@@ -147,6 +156,41 @@ export default function RequestRateLimit(props) {
                 />
               </Col>
             </Row>
+            <Row>
+              <Col xs={24} sm={16}>
+                <Form.TextArea
+                  label={t('分组速率限制')}
+                  placeholder={t(
+                    '{\n  "default": [200, 100],\n  "vip": [0, 1000]\n}',
+                  )}
+                  field={'ModelRequestRateLimitGroup'}
+                autosize={{ minRows: 5, maxRows: 15 }}
+                trigger='blur'
+                        stopValidateWithError
+                rules={[
+                  {
+                  validator: (rule, value) => verifyJSON(value),
+                  message: t('不是合法的 JSON 字符串'),
+                  },
+                ]}
+                  extraText={
+                    <div>
+                      <p style={{ marginBottom: -15 }}>{t('说明:')}</p>
+                      <ul>
+                        <li>{t('使用 JSON 对象格式,格式为:{"组名": [最多请求次数, 最多请求完成次数]}')}</li>
+                      <li>{t('示例:{"default": [200, 100], "vip": [0, 1000]}。')}</li>
+                      <li>{t('[最多请求次数]必须大于等于0,[最多请求完成次数]必须大于等于1。')}</li>
+                        <li>{t('分组速率配置优先级高于全局速率限制。')}</li>
+                        <li>{t('限制周期统一使用上方配置的“限制周期”值。')}</li>
+                      </ul>
+                    </div>
+                  }
+                  onChange={(value) => {
+                    setInputs({ ...inputs, ModelRequestRateLimitGroup: value });
+                  }}
+                />
+              </Col>
+            </Row>
             <Row>
               <Button size='default' onClick={onSubmit}>
                 {t('保存模型速率限制')}