Просмотр исходного кода

feat: Add model request rate limiting functionality

[email protected] 10 месяцев назад
Родитель
Сommit
83a37e4653

+ 172 - 0
middleware/model-rate-limit.go

@@ -0,0 +1,172 @@
+package middleware
+
+import (
+	"context"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/setting"
+	"strconv"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/go-redis/redis/v8"
+)
+
+const (
+	ModelRequestRateLimitCountMark        = "MRRL"
+	ModelRequestRateLimitSuccessCountMark = "MRRLS"
+)
+
+// 检查Redis中的请求限制
+func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) {
+	// 如果maxCount为0,表示不限制
+	if maxCount == 0 {
+		return true, nil
+	}
+
+	// 获取当前计数
+	length, err := rdb.LLen(ctx, key).Result()
+	if err != nil {
+		return false, err
+	}
+
+	// 如果未达到限制,允许请求
+	if length < int64(maxCount) {
+		return true, nil
+	}
+
+	// 检查时间窗口
+	oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
+	oldTime, err := time.Parse(timeFormat, oldTimeStr)
+	if err != nil {
+		return false, err
+	}
+
+	nowTimeStr := time.Now().Format(timeFormat)
+	nowTime, err := time.Parse(timeFormat, nowTimeStr)
+	if err != nil {
+		return false, err
+	}
+	// 如果在时间窗口内已达到限制,拒绝请求
+	subTime := nowTime.Sub(oldTime).Seconds()
+	if int64(subTime) < duration {
+		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+		return false, nil
+	}
+
+	return true, nil
+}
+
+// 记录Redis请求
+func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) {
+	// 如果maxCount为0,不记录请求
+	if maxCount == 0 {
+		return
+	}
+
+	now := time.Now().Format(timeFormat)
+	rdb.LPush(ctx, key, now)
+	rdb.LTrim(ctx, key, 0, int64(maxCount-1))
+	rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+}
+
+// Redis限流处理器
+func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		userId := strconv.Itoa(c.GetInt("id"))
+		ctx := context.Background()
+		rdb := common.RDB
+
+		// 1. 检查总请求数限制(当totalMaxCount为0时会自动跳过)
+		totalKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitCountMark, userId)
+		allowed, err := checkRedisRateLimit(ctx, rdb, totalKey, totalMaxCount, duration)
+		if err != nil {
+			fmt.Println("检查总请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount))
+		}
+
+		// 2. 检查成功请求数限制
+		successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId)
+		allowed, err = checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration)
+		if err != nil {
+			fmt.Println("检查成功请求数限制失败:", err.Error())
+			abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed")
+			return
+		}
+		if !allowed {
+			abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount))
+			return
+		}
+
+		// 3. 记录总请求(当totalMaxCount为0时会自动跳过)
+		recordRedisRequest(ctx, rdb, totalKey, totalMaxCount)
+
+		// 4. 处理请求
+		c.Next()
+
+		// 5. 如果请求成功,记录成功请求
+		if c.Writer.Status() < 400 {
+			recordRedisRequest(ctx, rdb, successKey, successMaxCount)
+		}
+	}
+}
+
+// 内存限流处理器
+func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc {
+	inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+
+	return func(c *gin.Context) {
+		userId := strconv.Itoa(c.GetInt("id"))
+		totalKey := ModelRequestRateLimitCountMark + userId
+		successKey := ModelRequestRateLimitSuccessCountMark + userId
+
+		// 1. 检查总请求数限制(当totalMaxCount为0时跳过)
+		if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) {
+			c.Status(http.StatusTooManyRequests)
+			c.Abort()
+			return
+		}
+
+		// 2. 检查成功请求数限制
+		// 使用一个临时key来检查限制,这样可以避免实际记录
+		checkKey := successKey + "_check"
+		if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) {
+			c.Status(http.StatusTooManyRequests)
+			c.Abort()
+			return
+		}
+
+		// 3. 处理请求
+		c.Next()
+
+		// 4. 如果请求成功,记录到实际的成功请求计数中
+		if c.Writer.Status() < 400 {
+			inMemoryRateLimiter.Request(successKey, successMaxCount, duration)
+		}
+	}
+}
+
+// ModelRequestRateLimit 模型请求限流中间件
+func ModelRequestRateLimit() func(c *gin.Context) {
+	// 如果未启用限流,直接放行
+	if !setting.ModelRequestRateLimitEnabled {
+		return defNext
+	}
+
+	// 计算限流参数
+	duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60)
+	totalMaxCount := setting.ModelRequestRateLimitCount
+	successMaxCount := setting.ModelRequestRateLimitSuccessCount
+
+	// 根据存储类型选择限流处理器
+	if common.RedisEnabled {
+		return redisRateLimitHandler(duration, totalMaxCount, successMaxCount)
+	} else {
+		return memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)
+	}
+}

+ 13 - 0
model/option.go

@@ -85,6 +85,9 @@ func InitOptionMap() {
 	common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 	common.OptionMap["ShouldPreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
+	common.OptionMap["ModelRequestRateLimitCount"] = strconv.Itoa(setting.ModelRequestRateLimitCount)
+	common.OptionMap["ModelRequestRateLimitDurationMinutes"] = strconv.Itoa(setting.ModelRequestRateLimitDurationMinutes)
+	common.OptionMap["ModelRequestRateLimitSuccessCount"] = strconv.Itoa(setting.ModelRequestRateLimitSuccessCount)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 	common.OptionMap["ModelPrice"] = common.ModelPrice2JSONString()
 	common.OptionMap["GroupRatio"] = setting.GroupRatio2JSONString()
@@ -105,6 +108,7 @@ func InitOptionMap() {
 	common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
 	common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
 	common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
+	common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
 	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
 	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
 	common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
@@ -226,6 +230,9 @@ func updateOptionMap(key string, value string) (err error) {
 			setting.DemoSiteEnabled = boolValue
 		case "CheckSensitiveOnPromptEnabled":
 			setting.CheckSensitiveOnPromptEnabled = boolValue
+		case "ModelRequestRateLimitEnabled":
+			setting.ModelRequestRateLimitEnabled = boolValue
+
 		//case "CheckSensitiveOnCompletionEnabled":
 		//	constant.CheckSensitiveOnCompletionEnabled = boolValue
 		case "StopOnSensitiveEnabled":
@@ -308,6 +315,12 @@ func updateOptionMap(key string, value string) (err error) {
 		common.QuotaRemindThreshold, _ = strconv.Atoi(value)
 	case "ShouldPreConsumedQuota":
 		common.PreConsumedQuota, _ = strconv.Atoi(value)
+	case "ModelRequestRateLimitCount":
+		setting.ModelRequestRateLimitCount, _ = strconv.Atoi(value)
+	case "ModelRequestRateLimitDurationMinutes":
+		setting.ModelRequestRateLimitDurationMinutes, _ = strconv.Atoi(value)
+	case "ModelRequestRateLimitSuccessCount":
+		setting.ModelRequestRateLimitSuccessCount, _ = strconv.Atoi(value)
 	case "RetryTimes":
 		common.RetryTimes, _ = strconv.Atoi(value)
 	case "DataExportInterval":

+ 1 - 0
router/relay-router.go

@@ -24,6 +24,7 @@ func SetRelayRouter(router *gin.Engine) {
 	}
 	relayV1Router := router.Group("/v1")
 	relayV1Router.Use(middleware.TokenAuth())
+	relayV1Router.Use(middleware.ModelRequestRateLimit())
 	{
 		// WebSocket 路由
 		wsRouter := relayV1Router.Group("")

+ 6 - 0
setting/rate_limit.go

@@ -0,0 +1,6 @@
+package setting
+
+var ModelRequestRateLimitEnabled = false
+var ModelRequestRateLimitDurationMinutes = 1
+var ModelRequestRateLimitCount = 0
+var ModelRequestRateLimitSuccessCount = 1000

+ 80 - 0
web/src/components/RateLimitSetting.js

@@ -0,0 +1,80 @@
+import React, { useEffect, useState } from 'react';
+import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
+import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
+import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
+import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
+import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
+import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
+import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
+import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
+import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
+import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
+import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
+import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
+
+
+import { API, showError, showSuccess } from '../helpers';
+import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
+import { useTranslation } from 'react-i18next';
+import RequestRateLimit from '../pages/Setting/RateLimit/SettingsRequestRateLimit.js';
+
+const RateLimitSetting = () => {
+  const { t } = useTranslation();
+  let [inputs, setInputs] = useState({
+    ModelRequestRateLimitEnabled: false,
+    ModelRequestRateLimitCount: 0,
+    ModelRequestRateLimitSuccessCount: 1000,
+    ModelRequestRateLimitDurationMinutes: 1,
+  });
+
+  let [loading, setLoading] = useState(false);
+
+  const getOptions = async () => {
+    const res = await API.get('/api/option/');
+    const { success, message, data } = res.data;
+    if (success) {
+      let newInputs = {};
+      data.forEach((item) => {
+        if (
+          item.key.endsWith('Enabled')
+        ) {
+          newInputs[item.key] = item.value === 'true' ? true : false;
+        } else {
+          newInputs[item.key] = item.value;
+        }
+      });
+
+      setInputs(newInputs);
+    } else {
+      showError(message);
+    }
+  };
+  async function onRefresh() {
+    try {
+      setLoading(true);
+      await getOptions();
+      // showSuccess('刷新成功');
+    } catch (error) {
+      showError('刷新失败');
+    } finally {
+      setLoading(false);
+    }
+  }
+
+  useEffect(() => {
+    onRefresh();
+  }, []);
+
+  return (
+    <>
+      <Spin spinning={loading} size='large'>
+        {/* AI请求速率限制 */}
+        <Card style={{ marginTop: '10px' }}>
+          <RequestRateLimit options={inputs} refresh={onRefresh} />
+        </Card>
+      </Spin>
+    </>
+  );
+};
+
+export default RateLimitSetting;

+ 159 - 0
web/src/pages/Setting/RateLimit/SettingsRequestRateLimit.js

@@ -0,0 +1,159 @@
+import React, { useEffect, useState, useRef } from 'react';
+import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
+import {
+  compareObjects,
+  API,
+  showError,
+  showSuccess,
+  showWarning,
+} from '../../../helpers';
+import { useTranslation } from 'react-i18next';
+
+export default function RequestRateLimit(props) {
+  const { t } = useTranslation();
+
+  const [loading, setLoading] = useState(false);
+  const [inputs, setInputs] = useState({
+    ModelRequestRateLimitEnabled: false,
+    ModelRequestRateLimitCount: -1,
+    ModelRequestRateLimitSuccessCount: 1000,
+    ModelRequestRateLimitDurationMinutes: 1
+  });
+  const refForm = useRef();
+  const [inputsRow, setInputsRow] = useState(inputs);
+
+  function onSubmit() {
+    const updateArray = compareObjects(inputs, inputsRow);
+    if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
+    const requestQueue = updateArray.map((item) => {
+      let value = '';
+      if (typeof inputs[item.key] === 'boolean') {
+        value = String(inputs[item.key]);
+      } else {
+        value = inputs[item.key];
+      }
+      return API.put('/api/option/', {
+        key: item.key,
+        value,
+      });
+    });
+    setLoading(true);
+    Promise.all(requestQueue)
+      .then((res) => {
+        if (requestQueue.length === 1) {
+          if (res.includes(undefined)) return;
+        } else if (requestQueue.length > 1) {
+          if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
+        }
+        showSuccess(t('保存成功'));
+        props.refresh();
+      })
+      .catch(() => {
+        showError(t('保存失败,请重试'));
+      })
+      .finally(() => {
+        setLoading(false);
+      });
+  }
+
+  useEffect(() => {
+    const currentInputs = {};
+    for (let key in props.options) {
+      if (Object.keys(inputs).includes(key)) {
+        currentInputs[key] = props.options[key];
+      }
+    }
+    setInputs(currentInputs);
+    setInputsRow(structuredClone(currentInputs));
+    refForm.current.setValues(currentInputs);
+  }, [props.options]);
+
+  return (
+    <>
+      <Spin spinning={loading}>
+        <Form
+          values={inputs}
+          getFormApi={(formAPI) => (refForm.current = formAPI)}
+          style={{ marginBottom: 15 }}
+        >
+          <Form.Section text={t('模型请求速率限制')}>
+            <Row gutter={16}>
+              <Col span={8}>
+                <Form.Switch
+                  field={'ModelRequestRateLimitEnabled'}
+                  label={t('启用用户模型请求速率限制(可能会影响高并发性能)')}
+                  size='default'
+                  checkedText='|'
+                  uncheckedText='〇'
+                  onChange={(value) => {
+                    setInputs({
+                      ...inputs,
+                      ModelRequestRateLimitEnabled: value,
+                    });
+                  }}
+                />
+              </Col>
+            </Row>
+            <Row>
+              <Col span={8}>
+                <Form.InputNumber
+                  label={t('限制周期')}
+                  step={1}
+                  min={0}
+                  suffix={t('分钟')}
+                  extraText={t('频率限制的周期(分钟)')}
+                  field={'ModelRequestRateLimitDurationMinutes'}
+                  onChange={(value) =>
+                    setInputs({
+                      ...inputs,
+                      ModelRequestRateLimitDurationMinutes: String(value),
+                    })
+                  }
+                />
+              </Col>
+            </Row>
+            <Row>
+              <Col span={8}>
+                <Form.InputNumber
+                  label={t('用户每周期最多请求次数')}
+                  step={1}
+                  min={0}
+                  suffix={t('次')}
+                  extraText={t('包括失败请求的次数,0代表不限制')}
+                  field={'ModelRequestRateLimitCount'}
+                  onChange={(value) =>
+                    setInputs({
+                      ...inputs,
+                      ModelRequestRateLimitCount: String(value),
+                    })
+                  }
+                />
+              </Col>
+              <Col span={8}>
+                <Form.InputNumber
+                  label={t('用户每周期最多请求完成次数')}
+                  step={1}
+                  min={1}
+                  suffix={t('次')}
+                  extraText={t('只包括请求成功的次数')}
+                  field={'ModelRequestRateLimitSuccessCount'}
+                  onChange={(value) =>
+                    setInputs({
+                      ...inputs,
+                      ModelRequestRateLimitSuccessCount: String(value),
+                    })
+                  }
+                />
+              </Col>
+            </Row>
+            <Row>
+              <Button size='default' onClick={onSubmit}>
+                {t('保存模型速率限制')}
+              </Button>
+            </Row>
+          </Form.Section>
+        </Form>
+      </Spin>
+    </>
+  );
+}

+ 6 - 0
web/src/pages/Setting/index.js

@@ -8,6 +8,7 @@ import { isRoot } from '../../helpers';
 import OtherSetting from '../../components/OtherSetting';
 import PersonalSetting from '../../components/PersonalSetting';
 import OperationSetting from '../../components/OperationSetting';
+import RateLimitSetting from '../../components/RateLimitSetting.js';
 
 const Setting = () => {
   const { t } = useTranslation();
@@ -28,6 +29,11 @@ const Setting = () => {
       content: <OperationSetting />,
       itemKey: 'operation',
     });
+    panes.push({
+      tab: t('速率限制设置'),
+      content: <RateLimitSetting />,
+      itemKey: 'ratelimit',
+    });
     panes.push({
       tab: t('系统设置'),
       content: <SystemSetting />,