فهرست منبع

fix: harden token search with pagination, rate limiting and input validation

- Add configurable per-user token creation limit (max_user_tokens)
- Sanitize search input patterns to prevent expensive queries
- Add per-user search rate limiting (by user ID)
- Add pagination to search endpoint with strict page size cap
- Skip empty search fields instead of matching nothing
- Hide internal errors from API responses
- Fix Interface2String float64 formatting causing config parse failures
- Add float-string fallback in config system for int/uint fields
CaIon 6 روز پیش
والد
کامیت
3e1be18310

+ 4 - 0
common/constants.go

@@ -175,6 +175,10 @@ var (
 
 	DownloadRateLimitNum            = 10
 	DownloadRateLimitDuration int64 = 60
+
+	// Per-user search rate limit (applies after authentication, keyed by user ID)
+	SearchRateLimitNum            = 10
+	SearchRateLimitDuration int64 = 60
 )
 
 var RateLimitKeyExpirationDuration = 20 * time.Minute

+ 1 - 1
common/utils.go

@@ -192,7 +192,7 @@ func Interface2String(inter interface{}) string {
 	case int:
 		return fmt.Sprintf("%d", inter.(int))
 	case float64:
-		return fmt.Sprintf("%f", inter.(float64))
+		return strconv.FormatFloat(inter.(float64), 'f', -1, 64)
 	case bool:
 		if inter.(bool) {
 			return "true"

+ 22 - 6
controller/token.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/QuantumNous/new-api/common"
 	"github.com/QuantumNous/new-api/model"
+	"github.com/QuantumNous/new-api/setting/operation_setting"
 
 	"github.com/gin-gonic/gin"
 )
@@ -31,16 +32,17 @@ func SearchTokens(c *gin.Context) {
 	userId := c.GetInt("id")
 	keyword := c.Query("keyword")
 	token := c.Query("token")
-	tokens, err := model.SearchUserTokens(userId, keyword, token)
+
+	pageInfo := common.GetPageQuery(c)
+
+	tokens, total, err := model.SearchUserTokens(userId, keyword, token, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
 	if err != nil {
 		common.ApiError(c, err)
 		return
 	}
-	c.JSON(http.StatusOK, gin.H{
-		"success": true,
-		"message": "",
-		"data":    tokens,
-	})
+	pageInfo.SetTotal(int(total))
+	pageInfo.SetItems(tokens)
+	common.ApiSuccess(c, pageInfo)
 	return
 }
 
@@ -168,6 +170,20 @@ func AddToken(c *gin.Context) {
 			return
 		}
 	}
+	// 检查用户令牌数量是否已达上限
+	maxTokens := operation_setting.GetMaxUserTokens()
+	count, err := model.CountUserTokens(c.GetInt("id"))
+	if err != nil {
+		common.ApiError(c, err)
+		return
+	}
+	if int(count) >= maxTokens {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": fmt.Sprintf("已达到最大令牌数量限制 (%d)", maxTokens),
+		})
+		return
+	}
 	key, err := common.GenerateKey()
 	if err != nil {
 		c.JSON(http.StatusOK, gin.H{

+ 85 - 0
middleware/rate-limit.go

@@ -115,3 +115,88 @@ func DownloadRateLimit() func(c *gin.Context) {
 func UploadRateLimit() func(c *gin.Context) {
 	return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
 }
+
+// userRateLimitFactory creates a rate limiter keyed by authenticated user ID
+// instead of client IP, making it resistant to proxy rotation attacks.
+// Must be used AFTER authentication middleware (UserAuth).
+func userRateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) {
+	if common.RedisEnabled {
+		return func(c *gin.Context) {
+			userId := c.GetInt("id")
+			if userId == 0 {
+				c.Status(http.StatusUnauthorized)
+				c.Abort()
+				return
+			}
+			key := fmt.Sprintf("rateLimit:%s:user:%d", mark, userId)
+			userRedisRateLimiter(c, maxRequestNum, duration, key)
+		}
+	}
+	// It's safe to call multi times.
+	inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
+	return func(c *gin.Context) {
+		userId := c.GetInt("id")
+		if userId == 0 {
+			c.Status(http.StatusUnauthorized)
+			c.Abort()
+			return
+		}
+		key := fmt.Sprintf("%s:user:%d", mark, userId)
+		if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
+			c.Status(http.StatusTooManyRequests)
+			c.Abort()
+			return
+		}
+	}
+}
+
+// userRedisRateLimiter is like redisRateLimiter but accepts a pre-built key
+// (to support user-ID-based keys).
+func userRedisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, key string) {
+	ctx := context.Background()
+	rdb := common.RDB
+	listLength, err := rdb.LLen(ctx, key).Result()
+	if err != nil {
+		fmt.Println(err.Error())
+		c.Status(http.StatusInternalServerError)
+		c.Abort()
+		return
+	}
+	if listLength < int64(maxRequestNum) {
+		rdb.LPush(ctx, key, time.Now().Format(timeFormat))
+		rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+	} else {
+		oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
+		oldTime, err := time.Parse(timeFormat, oldTimeStr)
+		if err != nil {
+			fmt.Println(err)
+			c.Status(http.StatusInternalServerError)
+			c.Abort()
+			return
+		}
+		nowTimeStr := time.Now().Format(timeFormat)
+		nowTime, err := time.Parse(timeFormat, nowTimeStr)
+		if err != nil {
+			fmt.Println(err)
+			c.Status(http.StatusInternalServerError)
+			c.Abort()
+			return
+		}
+		if int64(nowTime.Sub(oldTime).Seconds()) < duration {
+			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+			c.Status(http.StatusTooManyRequests)
+			c.Abort()
+			return
+		} else {
+			rdb.LPush(ctx, key, time.Now().Format(timeFormat))
+			rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
+			rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
+		}
+	}
+}
+
+// SearchRateLimit returns a per-user rate limiter for search endpoints.
+// 10 requests per 60 seconds per user (by user ID, not IP).
+func SearchRateLimit() func(c *gin.Context) {
+	return userRateLimitFactory(common.SearchRateLimitNum, common.SearchRateLimitDuration, "SR")
+}

+ 95 - 3
model/token.go

@@ -6,6 +6,7 @@ import (
 	"strings"
 
 	"github.com/QuantumNous/new-api/common"
+	"github.com/QuantumNous/new-api/setting/operation_setting"
 	"github.com/bytedance/gopkg/util/gopool"
 	"gorm.io/gorm"
 )
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
 	return tokens, err
 }
 
-func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
+// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
+// 规则:
+//  1. 转义 _ 和 \(不允许 _ 作通配符)
+//  2. 连续的 % 合并为单个 %
+//  3. 最多允许 2 个 %
+//  4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
+//  5. 不含 % 时按精确匹配
+func sanitizeLikePattern(input string) (string, error) {
+	// 1. 转义 \ 和 _
+	input = strings.ReplaceAll(input, `\`, `\\`)
+	input = strings.ReplaceAll(input, `_`, `\_`)
+
+	// 2. 连续的 % 直接拒绝
+	if strings.Contains(input, "%%") {
+		return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
+	}
+
+	// 3. 统计 % 数量,不得超过 2
+	count := strings.Count(input, "%")
+	if count > 2 {
+		return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
+	}
+
+	// 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
+	if count > 0 {
+		stripped := strings.ReplaceAll(input, "%", "")
+		if len(stripped) < 2 {
+			return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
+		}
+		return input, nil
+	}
+
+	// 5. 无 % 时,精确全匹配
+	return input, nil
+}
+
+const searchHardLimit = 100
+
+func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
+	// model 层强制截断
+	if limit <= 0 || limit > searchHardLimit {
+		limit = searchHardLimit
+	}
+	if offset < 0 {
+		offset = 0
+	}
+
 	if token != "" {
 		token = strings.Trim(token, "sk-")
 	}
-	err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
-	return tokens, err
+
+	// 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
+	maxTokens := operation_setting.GetMaxUserTokens()
+	hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
+	if hasFuzzy {
+		count, err := CountUserTokens(userId)
+		if err != nil {
+			common.SysLog("failed to count user tokens: " + err.Error())
+			return nil, 0, errors.New("获取令牌数量失败")
+		}
+		if int(count) > maxTokens {
+			return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
+		}
+	}
+
+	baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
+
+	// 非空才加 LIKE 条件,空则跳过(不过滤该字段)
+	if keyword != "" {
+		keywordPattern, err := sanitizeLikePattern(keyword)
+		if err != nil {
+			return nil, 0, err
+		}
+		baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
+	}
+	if token != "" {
+		tokenPattern, err := sanitizeLikePattern(token)
+		if err != nil {
+			return nil, 0, err
+		}
+		baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
+	}
+
+	// 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
+	err = baseQuery.Limit(maxTokens).Count(&total).Error
+	if err != nil {
+		common.SysError("failed to count search tokens: " + err.Error())
+		return nil, 0, errors.New("搜索令牌失败")
+	}
+
+	// 再分页查数据
+	err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
+	if err != nil {
+		common.SysError("failed to search tokens: " + err.Error())
+		return nil, 0, errors.New("搜索令牌失败")
+	}
+	return tokens, total, nil
 }
 
 func ValidateUserToken(key string) (token *Token, err error) {

+ 1 - 1
router/api-router.go

@@ -186,7 +186,7 @@ func SetApiRouter(router *gin.Engine) {
 		tokenRoute.Use(middleware.UserAuth())
 		{
 			tokenRoute.GET("/", controller.GetAllTokens)
-			tokenRoute.GET("/search", controller.SearchTokens)
+			tokenRoute.GET("/search", middleware.SearchRateLimit(), controller.SearchTokens)
 			tokenRoute.GET("/:id", controller.GetToken)
 			tokenRoute.POST("/", controller.AddToken)
 			tokenRoute.PUT("/", controller.UpdateToken)

+ 12 - 2
setting/config/config.go

@@ -212,13 +212,23 @@ func updateConfigFromMap(config interface{}, configMap map[string]string) error
 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
 			intValue, err := strconv.ParseInt(strValue, 10, 64)
 			if err != nil {
-				continue
+				// 兼容 float 格式的字符串(如 "2.000000")
+				floatValue, fErr := strconv.ParseFloat(strValue, 64)
+				if fErr != nil {
+					continue
+				}
+				intValue = int64(floatValue)
 			}
 			field.SetInt(intValue)
 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
 			uintValue, err := strconv.ParseUint(strValue, 10, 64)
 			if err != nil {
-				continue
+				// 兼容 float 格式的字符串
+				floatValue, fErr := strconv.ParseFloat(strValue, 64)
+				if fErr != nil || floatValue < 0 {
+					continue
+				}
+				uintValue = uint64(floatValue)
 			}
 			field.SetUint(uintValue)
 		case reflect.Float32, reflect.Float64:

+ 28 - 0
setting/operation_setting/token_setting.go

@@ -0,0 +1,28 @@
+package operation_setting
+
+import "github.com/QuantumNous/new-api/setting/config"
+
+// TokenSetting 令牌相关配置
+type TokenSetting struct {
+	MaxUserTokens int `json:"max_user_tokens"` // 每用户最大令牌数量
+}
+
+// 默认配置
+var tokenSetting = TokenSetting{
+	MaxUserTokens: 1000, // 默认每用户最多 1000 个令牌
+}
+
+func init() {
+	// 注册到全局配置管理器
+	config.GlobalConfig.Register("token_setting", &tokenSetting)
+}
+
+// GetTokenSetting 获取令牌配置
+func GetTokenSetting() *TokenSetting {
+	return &tokenSetting
+}
+
+// GetMaxUserTokens 获取每用户最大令牌数量
+func GetMaxUserTokens() int {
+	return GetTokenSetting().MaxUserTokens
+}

+ 3 - 0
web/src/components/settings/OperationSetting.jsx

@@ -77,6 +77,9 @@ const OperationSetting = () => {
     'checkin_setting.enabled': false,
     'checkin_setting.min_quota': 1000,
     'checkin_setting.max_quota': 10000,
+
+    /* 令牌设置 */
+    'token_setting.max_user_tokens': 1000,
   });
 
   let [loading, setLoading] = useState(false);

+ 17 - 7
web/src/hooks/tokens/useTokensData.jsx

@@ -40,6 +40,7 @@ export const useTokensData = (openFluentNotification) => {
   const [tokenCount, setTokenCount] = useState(0);
   const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
   const [searching, setSearching] = useState(false);
+  const [searchMode, setSearchMode] = useState(false); // 是否处于搜索结果视图
 
   // Selection state
   const [selectedKeys, setSelectedKeys] = useState([]);
@@ -91,6 +92,7 @@ export const useTokensData = (openFluentNotification) => {
   // Load tokens function
   const loadTokens = async (page = 1, size = pageSize) => {
     setLoading(true);
+    setSearchMode(false);
     const res = await API.get(`/api/token/?p=${page}&size=${size}`);
     const { success, message, data } = res.data;
     if (success) {
@@ -188,21 +190,21 @@ export const useTokensData = (openFluentNotification) => {
   };
 
   // Search tokens function
-  const searchTokens = async () => {
+  const searchTokens = async (page = 1, size = pageSize) => {
     const { searchKeyword, searchToken } = getFormValues();
     if (searchKeyword === '' && searchToken === '') {
+      setSearchMode(false);
       await loadTokens(1);
       return;
     }
     setSearching(true);
     const res = await API.get(
-      `/api/token/search?keyword=${searchKeyword}&token=${searchToken}`,
+      `/api/token/search?keyword=${encodeURIComponent(searchKeyword)}&token=${encodeURIComponent(searchToken)}&p=${page}&size=${size}`,
     );
     const { success, message, data } = res.data;
     if (success) {
-      setTokens(data);
-      setTokenCount(data.length);
-      setActivePage(1);
+      setSearchMode(true);
+      syncPageData(data);
     } else {
       showError(message);
     }
@@ -226,12 +228,20 @@ export const useTokensData = (openFluentNotification) => {
 
   // Page handlers
   const handlePageChange = (page) => {
-    loadTokens(page, pageSize).then();
+    if (searchMode) {
+      searchTokens(page, pageSize).then();
+    } else {
+      loadTokens(page, pageSize).then();
+    }
   };
 
   const handlePageSizeChange = async (size) => {
     setPageSize(size);
-    await loadTokens(1, size);
+    if (searchMode) {
+      await searchTokens(1, size);
+    } else {
+      await loadTokens(1, size);
+    }
   };
 
   // Row selection handlers

+ 14 - 0
web/src/pages/Setting/Operation/SettingsGeneral.jsx

@@ -56,6 +56,7 @@ export default function GeneralSettings(props) {
     DefaultCollapseSidebar: false,
     DemoSiteEnabled: false,
     SelfUseModeEnabled: false,
+    'token_setting.max_user_tokens': 1000,
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -287,6 +288,19 @@ export default function GeneralSettings(props) {
                 />
               </Col>
             </Row>
+            <Row gutter={16}>
+              <Col xs={24} sm={12} md={8} lg={8} xl={8}>
+                <Form.InputNumber
+                  label={t('用户最大令牌数量')}
+                  field={'token_setting.max_user_tokens'}
+                  step={1}
+                  min={1}
+                  extraText={t('每个用户最多可创建的令牌数量,默认 1000,设置过大可能会影响性能')}
+                  placeholder={'1000'}
+                  onChange={handleFieldChange('token_setting.max_user_tokens')}
+                />
+              </Col>
+            </Row>
             <Row>
               <Button size='default' onClick={onSubmit}>
                 {t('保存通用设置')}