override.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. package common
  2. import (
  3. "fmt"
  4. "regexp"
  5. "strconv"
  6. "strings"
  7. "github.com/QuantumNous/new-api/common"
  8. "github.com/tidwall/gjson"
  9. "github.com/tidwall/sjson"
  10. )
// negativeIndexRegexp matches a dot followed by a negative integer (for
// example ".-1") inside a path and captures the signed number.
var negativeIndexRegexp = regexp.MustCompile(`\.(-\d+)`)
// ConditionOperation describes one condition that gates a ParamOperation.
type ConditionOperation struct {
	Path           string      `json:"path"`             // JSON path of the value to test
	Mode           string      `json:"mode"`             // full, prefix, suffix, contains, gt, gte, lt, lte
	Value          interface{} `json:"value"`            // value to compare against
	Invert         bool        `json:"invert"`           // when true, the comparison result is negated
	PassMissingKey bool        `json:"pass_missing_key"` // result to use when the JSON key cannot be found
}
// ParamOperation describes one edit to apply to the request JSON, optionally
// guarded by a list of conditions.
type ParamOperation struct {
	Path       string               `json:"path"`
	Mode       string               `json:"mode"` // delete, set, move, copy, prepend, append, trim_prefix, trim_suffix, ensure_prefix, ensure_suffix, trim_space, to_lower, to_upper, replace, regex_replace
	Value      interface{}          `json:"value"`
	KeepOrigin bool                 `json:"keep_origin"`
	From       string               `json:"from,omitempty"`
	To         string               `json:"to,omitempty"`
	Conditions []ConditionOperation `json:"conditions,omitempty"` // list of conditions
	Logic      string               `json:"logic,omitempty"`      // AND, OR (defaults to OR)
}
  29. func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
  30. if len(paramOverride) == 0 {
  31. return jsonData, nil
  32. }
  33. // 尝试断言为操作格式
  34. if operations, ok := tryParseOperations(paramOverride); ok {
  35. // 使用新方法
  36. result, err := applyOperations(string(jsonData), operations, conditionContext)
  37. return []byte(result), err
  38. }
  39. // 直接使用旧方法
  40. return applyOperationsLegacy(jsonData, paramOverride)
  41. }
  42. func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
  43. // 检查是否包含 "operations" 字段
  44. if opsValue, exists := paramOverride["operations"]; exists {
  45. if opsSlice, ok := opsValue.([]interface{}); ok {
  46. var operations []ParamOperation
  47. for _, op := range opsSlice {
  48. if opMap, ok := op.(map[string]interface{}); ok {
  49. operation := ParamOperation{}
  50. // 断言必要字段
  51. if path, ok := opMap["path"].(string); ok {
  52. operation.Path = path
  53. }
  54. if mode, ok := opMap["mode"].(string); ok {
  55. operation.Mode = mode
  56. } else {
  57. return nil, false // mode 是必需的
  58. }
  59. // 可选字段
  60. if value, exists := opMap["value"]; exists {
  61. operation.Value = value
  62. }
  63. if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
  64. operation.KeepOrigin = keepOrigin
  65. }
  66. if from, ok := opMap["from"].(string); ok {
  67. operation.From = from
  68. }
  69. if to, ok := opMap["to"].(string); ok {
  70. operation.To = to
  71. }
  72. if logic, ok := opMap["logic"].(string); ok {
  73. operation.Logic = logic
  74. } else {
  75. operation.Logic = "OR" // 默认为OR
  76. }
  77. // 解析条件
  78. if conditions, exists := opMap["conditions"]; exists {
  79. if condSlice, ok := conditions.([]interface{}); ok {
  80. for _, cond := range condSlice {
  81. if condMap, ok := cond.(map[string]interface{}); ok {
  82. condition := ConditionOperation{}
  83. if path, ok := condMap["path"].(string); ok {
  84. condition.Path = path
  85. }
  86. if mode, ok := condMap["mode"].(string); ok {
  87. condition.Mode = mode
  88. }
  89. if value, ok := condMap["value"]; ok {
  90. condition.Value = value
  91. }
  92. if invert, ok := condMap["invert"].(bool); ok {
  93. condition.Invert = invert
  94. }
  95. if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
  96. condition.PassMissingKey = passMissingKey
  97. }
  98. operation.Conditions = append(operation.Conditions, condition)
  99. }
  100. }
  101. }
  102. }
  103. operations = append(operations, operation)
  104. } else {
  105. return nil, false
  106. }
  107. }
  108. return operations, true
  109. }
  110. }
  111. return nil, false
  112. }
  113. func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
  114. if len(conditions) == 0 {
  115. return true, nil // 没有条件,直接通过
  116. }
  117. results := make([]bool, len(conditions))
  118. for i, condition := range conditions {
  119. result, err := checkSingleCondition(jsonStr, contextJSON, condition)
  120. if err != nil {
  121. return false, err
  122. }
  123. results[i] = result
  124. }
  125. if strings.ToUpper(logic) == "AND" {
  126. for _, result := range results {
  127. if !result {
  128. return false, nil
  129. }
  130. }
  131. return true, nil
  132. } else {
  133. for _, result := range results {
  134. if result {
  135. return true, nil
  136. }
  137. }
  138. return false, nil
  139. }
  140. }
  141. func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
  142. // 处理负数索引
  143. path := processNegativeIndex(jsonStr, condition.Path)
  144. value := gjson.Get(jsonStr, path)
  145. if !value.Exists() && contextJSON != "" {
  146. value = gjson.Get(contextJSON, condition.Path)
  147. }
  148. if !value.Exists() {
  149. if condition.PassMissingKey {
  150. return true, nil
  151. }
  152. return false, nil
  153. }
  154. // 利用gjson的类型解析
  155. targetBytes, err := common.Marshal(condition.Value)
  156. if err != nil {
  157. return false, fmt.Errorf("failed to marshal condition value: %v", err)
  158. }
  159. targetValue := gjson.ParseBytes(targetBytes)
  160. result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
  161. if err != nil {
  162. return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
  163. }
  164. if condition.Invert {
  165. result = !result
  166. }
  167. return result, nil
  168. }
  169. func processNegativeIndex(jsonStr string, path string) string {
  170. matches := negativeIndexRegexp.FindAllStringSubmatch(path, -1)
  171. if len(matches) == 0 {
  172. return path
  173. }
  174. result := path
  175. for _, match := range matches {
  176. negIndex := match[1]
  177. index, _ := strconv.Atoi(negIndex)
  178. arrayPath := strings.Split(path, negIndex)[0]
  179. if strings.HasSuffix(arrayPath, ".") {
  180. arrayPath = arrayPath[:len(arrayPath)-1]
  181. }
  182. array := gjson.Get(jsonStr, arrayPath)
  183. if array.IsArray() {
  184. length := len(array.Array())
  185. actualIndex := length + index
  186. if actualIndex >= 0 && actualIndex < length {
  187. result = strings.Replace(result, match[0], "."+strconv.Itoa(actualIndex), 1)
  188. }
  189. }
  190. }
  191. return result
  192. }
  193. // compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
  194. func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
  195. switch mode {
  196. case "full":
  197. return compareEqual(jsonValue, targetValue)
  198. case "prefix":
  199. return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
  200. case "suffix":
  201. return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
  202. case "contains":
  203. return strings.Contains(jsonValue.String(), targetValue.String()), nil
  204. case "gt":
  205. return compareNumeric(jsonValue, targetValue, "gt")
  206. case "gte":
  207. return compareNumeric(jsonValue, targetValue, "gte")
  208. case "lt":
  209. return compareNumeric(jsonValue, targetValue, "lt")
  210. case "lte":
  211. return compareNumeric(jsonValue, targetValue, "lte")
  212. default:
  213. return false, fmt.Errorf("unsupported comparison mode: %s", mode)
  214. }
  215. }
  216. func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
  217. // 对null值特殊处理:两个都是null返回true,一个是null另一个不是返回false
  218. if jsonValue.Type == gjson.Null || targetValue.Type == gjson.Null {
  219. return jsonValue.Type == gjson.Null && targetValue.Type == gjson.Null, nil
  220. }
  221. // 对布尔值特殊处理
  222. if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
  223. (targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
  224. return jsonValue.Bool() == targetValue.Bool(), nil
  225. }
  226. // 如果类型不同,报错
  227. if jsonValue.Type != targetValue.Type {
  228. return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
  229. }
  230. switch jsonValue.Type {
  231. case gjson.True, gjson.False:
  232. return jsonValue.Bool() == targetValue.Bool(), nil
  233. case gjson.Number:
  234. return jsonValue.Num == targetValue.Num, nil
  235. case gjson.String:
  236. return jsonValue.String() == targetValue.String(), nil
  237. default:
  238. return jsonValue.String() == targetValue.String(), nil
  239. }
  240. }
  241. func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
  242. // 只有数字类型才支持数值比较
  243. if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
  244. return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
  245. }
  246. jsonNum := jsonValue.Num
  247. targetNum := targetValue.Num
  248. switch operator {
  249. case "gt":
  250. return jsonNum > targetNum, nil
  251. case "gte":
  252. return jsonNum >= targetNum, nil
  253. case "lt":
  254. return jsonNum < targetNum, nil
  255. case "lte":
  256. return jsonNum <= targetNum, nil
  257. default:
  258. return false, fmt.Errorf("unsupported numeric operator: %s", operator)
  259. }
  260. }
  261. // applyOperationsLegacy 原参数覆盖方法
  262. func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
  263. reqMap := make(map[string]interface{})
  264. err := common.Unmarshal(jsonData, &reqMap)
  265. if err != nil {
  266. return nil, err
  267. }
  268. for key, value := range paramOverride {
  269. reqMap[key] = value
  270. }
  271. return common.Marshal(reqMap)
  272. }
  273. func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
  274. var contextJSON string
  275. if conditionContext != nil && len(conditionContext) > 0 {
  276. ctxBytes, err := common.Marshal(conditionContext)
  277. if err != nil {
  278. return "", fmt.Errorf("failed to marshal condition context: %v", err)
  279. }
  280. contextJSON = string(ctxBytes)
  281. }
  282. result := jsonStr
  283. for _, op := range operations {
  284. // 检查条件是否满足
  285. ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
  286. if err != nil {
  287. return "", err
  288. }
  289. if !ok {
  290. continue // 条件不满足,跳过当前操作
  291. }
  292. // 处理路径中的负数索引
  293. opPath := processNegativeIndex(result, op.Path)
  294. switch op.Mode {
  295. case "delete":
  296. result, err = sjson.Delete(result, opPath)
  297. case "set":
  298. if op.KeepOrigin && gjson.Get(result, opPath).Exists() {
  299. continue
  300. }
  301. result, err = sjson.Set(result, opPath, op.Value)
  302. case "move":
  303. opFrom := processNegativeIndex(result, op.From)
  304. opTo := processNegativeIndex(result, op.To)
  305. result, err = moveValue(result, opFrom, opTo)
  306. case "copy":
  307. if op.From == "" || op.To == "" {
  308. return "", fmt.Errorf("copy from/to is required")
  309. }
  310. opFrom := processNegativeIndex(result, op.From)
  311. opTo := processNegativeIndex(result, op.To)
  312. result, err = copyValue(result, opFrom, opTo)
  313. case "prepend":
  314. result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, true)
  315. case "append":
  316. result, err = modifyValue(result, opPath, op.Value, op.KeepOrigin, false)
  317. case "trim_prefix":
  318. result, err = trimStringValue(result, opPath, op.Value, true)
  319. case "trim_suffix":
  320. result, err = trimStringValue(result, opPath, op.Value, false)
  321. case "ensure_prefix":
  322. result, err = ensureStringAffix(result, opPath, op.Value, true)
  323. case "ensure_suffix":
  324. result, err = ensureStringAffix(result, opPath, op.Value, false)
  325. case "trim_space":
  326. result, err = transformStringValue(result, opPath, strings.TrimSpace)
  327. case "to_lower":
  328. result, err = transformStringValue(result, opPath, strings.ToLower)
  329. case "to_upper":
  330. result, err = transformStringValue(result, opPath, strings.ToUpper)
  331. case "replace":
  332. result, err = replaceStringValue(result, opPath, op.From, op.To)
  333. case "regex_replace":
  334. result, err = regexReplaceStringValue(result, opPath, op.From, op.To)
  335. default:
  336. return "", fmt.Errorf("unknown operation: %s", op.Mode)
  337. }
  338. if err != nil {
  339. return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
  340. }
  341. }
  342. return result, nil
  343. }
  344. func moveValue(jsonStr, fromPath, toPath string) (string, error) {
  345. sourceValue := gjson.Get(jsonStr, fromPath)
  346. if !sourceValue.Exists() {
  347. return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
  348. }
  349. result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
  350. if err != nil {
  351. return "", err
  352. }
  353. return sjson.Delete(result, fromPath)
  354. }
  355. func copyValue(jsonStr, fromPath, toPath string) (string, error) {
  356. sourceValue := gjson.Get(jsonStr, fromPath)
  357. if !sourceValue.Exists() {
  358. return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
  359. }
  360. return sjson.Set(jsonStr, toPath, sourceValue.Value())
  361. }
  362. func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
  363. current := gjson.Get(jsonStr, path)
  364. switch {
  365. case current.IsArray():
  366. return modifyArray(jsonStr, path, value, isPrepend)
  367. case current.Type == gjson.String:
  368. return modifyString(jsonStr, path, value, isPrepend)
  369. case current.Type == gjson.JSON:
  370. return mergeObjects(jsonStr, path, value, keepOrigin)
  371. }
  372. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  373. }
  374. func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
  375. current := gjson.Get(jsonStr, path)
  376. var newArray []interface{}
  377. // 添加新值
  378. addValue := func() {
  379. if arr, ok := value.([]interface{}); ok {
  380. newArray = append(newArray, arr...)
  381. } else {
  382. newArray = append(newArray, value)
  383. }
  384. }
  385. // 添加原值
  386. addOriginal := func() {
  387. current.ForEach(func(_, val gjson.Result) bool {
  388. newArray = append(newArray, val.Value())
  389. return true
  390. })
  391. }
  392. if isPrepend {
  393. addValue()
  394. addOriginal()
  395. } else {
  396. addOriginal()
  397. addValue()
  398. }
  399. return sjson.Set(jsonStr, path, newArray)
  400. }
  401. func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
  402. current := gjson.Get(jsonStr, path)
  403. valueStr := fmt.Sprintf("%v", value)
  404. var newStr string
  405. if isPrepend {
  406. newStr = valueStr + current.String()
  407. } else {
  408. newStr = current.String() + valueStr
  409. }
  410. return sjson.Set(jsonStr, path, newStr)
  411. }
  412. func trimStringValue(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
  413. current := gjson.Get(jsonStr, path)
  414. if current.Type != gjson.String {
  415. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  416. }
  417. if value == nil {
  418. return jsonStr, fmt.Errorf("trim value is required")
  419. }
  420. valueStr := fmt.Sprintf("%v", value)
  421. var newStr string
  422. if isPrefix {
  423. newStr = strings.TrimPrefix(current.String(), valueStr)
  424. } else {
  425. newStr = strings.TrimSuffix(current.String(), valueStr)
  426. }
  427. return sjson.Set(jsonStr, path, newStr)
  428. }
  429. func ensureStringAffix(jsonStr, path string, value interface{}, isPrefix bool) (string, error) {
  430. current := gjson.Get(jsonStr, path)
  431. if current.Type != gjson.String {
  432. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  433. }
  434. if value == nil {
  435. return jsonStr, fmt.Errorf("ensure value is required")
  436. }
  437. valueStr := fmt.Sprintf("%v", value)
  438. if valueStr == "" {
  439. return jsonStr, fmt.Errorf("ensure value is required")
  440. }
  441. currentStr := current.String()
  442. if isPrefix {
  443. if strings.HasPrefix(currentStr, valueStr) {
  444. return jsonStr, nil
  445. }
  446. return sjson.Set(jsonStr, path, valueStr+currentStr)
  447. }
  448. if strings.HasSuffix(currentStr, valueStr) {
  449. return jsonStr, nil
  450. }
  451. return sjson.Set(jsonStr, path, currentStr+valueStr)
  452. }
  453. func transformStringValue(jsonStr, path string, transform func(string) string) (string, error) {
  454. current := gjson.Get(jsonStr, path)
  455. if current.Type != gjson.String {
  456. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  457. }
  458. return sjson.Set(jsonStr, path, transform(current.String()))
  459. }
  460. func replaceStringValue(jsonStr, path, from, to string) (string, error) {
  461. current := gjson.Get(jsonStr, path)
  462. if current.Type != gjson.String {
  463. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  464. }
  465. if from == "" {
  466. return jsonStr, fmt.Errorf("replace from is required")
  467. }
  468. return sjson.Set(jsonStr, path, strings.ReplaceAll(current.String(), from, to))
  469. }
  470. func regexReplaceStringValue(jsonStr, path, pattern, replacement string) (string, error) {
  471. current := gjson.Get(jsonStr, path)
  472. if current.Type != gjson.String {
  473. return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
  474. }
  475. if pattern == "" {
  476. return jsonStr, fmt.Errorf("regex pattern is required")
  477. }
  478. re, err := regexp.Compile(pattern)
  479. if err != nil {
  480. return jsonStr, err
  481. }
  482. return sjson.Set(jsonStr, path, re.ReplaceAllString(current.String(), replacement))
  483. }
  484. func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
  485. current := gjson.Get(jsonStr, path)
  486. var currentMap, newMap map[string]interface{}
  487. // 解析当前值
  488. if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
  489. return "", err
  490. }
  491. // 解析新值
  492. switch v := value.(type) {
  493. case map[string]interface{}:
  494. newMap = v
  495. default:
  496. jsonBytes, _ := common.Marshal(v)
  497. if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
  498. return "", err
  499. }
  500. }
  501. // 合并
  502. result := make(map[string]interface{})
  503. for k, v := range currentMap {
  504. result[k] = v
  505. }
  506. for k, v := range newMap {
  507. if !keepOrigin || result[k] == nil {
  508. result[k] = v
  509. }
  510. }
  511. return sjson.Set(jsonStr, path, result)
  512. }
  513. // BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
  514. // 目前内置以下字段:
  515. // - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
  516. // - upstream_model:始终为通道映射后的上游模型名。
  517. // - original_model:请求最初指定的模型名。
  518. func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
  519. if info == nil || info.ChannelMeta == nil {
  520. return nil
  521. }
  522. ctx := make(map[string]interface{})
  523. if info.UpstreamModelName != "" {
  524. ctx["model"] = info.UpstreamModelName
  525. ctx["upstream_model"] = info.UpstreamModelName
  526. }
  527. if info.OriginModelName != "" {
  528. ctx["original_model"] = info.OriginModelName
  529. if _, exists := ctx["model"]; !exists {
  530. ctx["model"] = info.OriginModelName
  531. }
  532. }
  533. if len(ctx) == 0 {
  534. return nil
  535. }
  536. return ctx
  537. }