|
|
@@ -6,6 +6,7 @@ import (
|
|
|
"strings"
|
|
|
|
|
|
"github.com/QuantumNous/new-api/common"
|
|
|
+ "github.com/QuantumNous/new-api/setting/operation_setting"
|
|
|
"github.com/bytedance/gopkg/util/gopool"
|
|
|
"gorm.io/gorm"
|
|
|
)
|
|
|
@@ -63,12 +64,103 @@ func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
|
|
|
return tokens, err
|
|
|
}
|
|
|
|
|
|
-func SearchUserTokens(userId int, keyword string, token string) (tokens []*Token, err error) {
|
|
|
+// sanitizeLikePattern 校验并清洗用户输入的 LIKE 搜索模式。
|
|
|
+// 规则:
|
|
|
+// 1. 转义 _ 和 \(不允许 _ 作通配符)
|
|
|
+// 2. 连续的 % 合并为单个 %
|
|
|
+// 3. 最多允许 2 个 %
|
|
|
+// 4. 含 % 时(模糊搜索),去掉 % 后关键词长度必须 >= 2
|
|
|
+// 5. 不含 % 时按精确匹配
|
|
|
+func sanitizeLikePattern(input string) (string, error) {
|
|
|
+ // 1. 转义 \ 和 _
|
|
|
+ input = strings.ReplaceAll(input, `\`, `\\`)
|
|
|
+ input = strings.ReplaceAll(input, `_`, `\_`)
|
|
|
+
|
|
|
+ // 2. 连续的 % 直接拒绝
|
|
|
+ if strings.Contains(input, "%%") {
|
|
|
+ return "", errors.New("搜索模式中不允许包含连续的 % 通配符")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 3. 统计 % 数量,不得超过 2
|
|
|
+ count := strings.Count(input, "%")
|
|
|
+ if count > 2 {
|
|
|
+ return "", errors.New("搜索模式中最多允许包含 2 个 % 通配符")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 4. 含 % 时,去掉 % 后关键词长度必须 >= 2
|
|
|
+ if count > 0 {
|
|
|
+ stripped := strings.ReplaceAll(input, "%", "")
|
|
|
+ if len(stripped) < 2 {
|
|
|
+ return "", errors.New("使用模糊搜索时,关键词长度至少为 2 个字符")
|
|
|
+ }
|
|
|
+ return input, nil
|
|
|
+ }
|
|
|
+
|
|
|
+ // 5. 无 % 时,精确全匹配
|
|
|
+ return input, nil
|
|
|
+}
|
|
|
+
|
|
|
+const searchHardLimit = 100
|
|
|
+
|
|
|
+func SearchUserTokens(userId int, keyword string, token string, offset int, limit int) (tokens []*Token, total int64, err error) {
|
|
|
+ // model 层强制截断
|
|
|
+ if limit <= 0 || limit > searchHardLimit {
|
|
|
+ limit = searchHardLimit
|
|
|
+ }
|
|
|
+ if offset < 0 {
|
|
|
+ offset = 0
|
|
|
+ }
|
|
|
+
|
|
|
if token != "" {
|
|
|
token = strings.Trim(token, "sk-")
|
|
|
}
|
|
|
- err = DB.Where("user_id = ?", userId).Where("name LIKE ?", "%"+keyword+"%").Where(commonKeyCol+" LIKE ?", "%"+token+"%").Find(&tokens).Error
|
|
|
- return tokens, err
|
|
|
+
|
|
|
+ // 超量用户(令牌数超过上限)只允许精确搜索,禁止模糊搜索
|
|
|
+ maxTokens := operation_setting.GetMaxUserTokens()
|
|
|
+ hasFuzzy := strings.Contains(keyword, "%") || strings.Contains(token, "%")
|
|
|
+ if hasFuzzy {
|
|
|
+ count, err := CountUserTokens(userId)
|
|
|
+ if err != nil {
|
|
|
+ common.SysLog("failed to count user tokens: " + err.Error())
|
|
|
+ return nil, 0, errors.New("获取令牌数量失败")
|
|
|
+ }
|
|
|
+ if int(count) > maxTokens {
|
|
|
+ return nil, 0, errors.New("令牌数量超过上限,仅允许精确搜索,请勿使用 % 通配符")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ baseQuery := DB.Model(&Token{}).Where("user_id = ?", userId)
|
|
|
+
|
|
|
+ // 非空才加 LIKE 条件,空则跳过(不过滤该字段)
|
|
|
+ if keyword != "" {
|
|
|
+ keywordPattern, err := sanitizeLikePattern(keyword)
|
|
|
+ if err != nil {
|
|
|
+ return nil, 0, err
|
|
|
+ }
|
|
|
+ baseQuery = baseQuery.Where("name LIKE ? ESCAPE '\\'", keywordPattern)
|
|
|
+ }
|
|
|
+ if token != "" {
|
|
|
+ tokenPattern, err := sanitizeLikePattern(token)
|
|
|
+ if err != nil {
|
|
|
+ return nil, 0, err
|
|
|
+ }
|
|
|
+ baseQuery = baseQuery.Where(commonKeyCol+" LIKE ? ESCAPE '\\'", tokenPattern)
|
|
|
+ }
|
|
|
+
|
|
|
+ // 先查匹配总数(用于分页,受 maxTokens 上限保护,避免全表 COUNT)
|
|
|
+ err = baseQuery.Limit(maxTokens).Count(&total).Error
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("failed to count search tokens: " + err.Error())
|
|
|
+ return nil, 0, errors.New("搜索令牌失败")
|
|
|
+ }
|
|
|
+
|
|
|
+ // 再分页查数据
|
|
|
+ err = baseQuery.Order("id desc").Offset(offset).Limit(limit).Find(&tokens).Error
|
|
|
+ if err != nil {
|
|
|
+ common.SysError("failed to search tokens: " + err.Error())
|
|
|
+ return nil, 0, errors.New("搜索令牌失败")
|
|
|
+ }
|
|
|
+ return tokens, total, nil
|
|
|
}
|
|
|
|
|
|
func ValidateUserToken(key string) (token *Token, err error) {
|