|
|
@@ -1,12 +1,12 @@
|
|
|
package common
|
|
|
|
|
|
import (
|
|
|
- "encoding/json"
|
|
|
"fmt"
|
|
|
"regexp"
|
|
|
"strconv"
|
|
|
"strings"
|
|
|
|
|
|
+ "github.com/QuantumNous/new-api/common"
|
|
|
"github.com/tidwall/gjson"
|
|
|
"github.com/tidwall/sjson"
|
|
|
)
|
|
|
@@ -30,7 +30,7 @@ type ParamOperation struct {
|
|
|
Logic string `json:"logic,omitempty"` // AND, OR (默认OR)
|
|
|
}
|
|
|
|
|
|
-func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
|
|
+func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
|
|
|
if len(paramOverride) == 0 {
|
|
|
return jsonData, nil
|
|
|
}
|
|
|
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
|
|
|
// 尝试断言为操作格式
|
|
|
if operations, ok := tryParseOperations(paramOverride); ok {
|
|
|
// 使用新方法
|
|
|
- result, err := applyOperations(string(jsonData), operations)
|
|
|
+ result, err := applyOperations(string(jsonData), operations, conditionContext)
|
|
|
return []byte(result), err
|
|
|
}
|
|
|
|
|
|
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
|
|
|
return nil, false
|
|
|
}
|
|
|
|
|
|
-func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
|
|
|
+func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
|
|
|
if len(conditions) == 0 {
|
|
|
return true, nil // 没有条件,直接通过
|
|
|
}
|
|
|
results := make([]bool, len(conditions))
|
|
|
for i, condition := range conditions {
|
|
|
- result, err := checkSingleCondition(jsonStr, condition)
|
|
|
+ result, err := checkSingleCondition(jsonStr, contextJSON, condition)
|
|
|
if err != nil {
|
|
|
return false, err
|
|
|
}
|
|
|
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
|
|
|
+func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
|
|
|
// 处理负数索引
|
|
|
path := processNegativeIndex(jsonStr, condition.Path)
|
|
|
value := gjson.Get(jsonStr, path)
|
|
|
+ if !value.Exists() && contextJSON != "" {
|
|
|
+ value = gjson.Get(contextJSON, condition.Path)
|
|
|
+ }
|
|
|
if !value.Exists() {
|
|
|
if condition.PassMissingKey {
|
|
|
return true, nil
|
|
|
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
|
|
|
}
|
|
|
|
|
|
// 利用gjson的类型解析
|
|
|
- targetBytes, err := json.Marshal(condition.Value)
|
|
|
+ targetBytes, err := common.Marshal(condition.Value)
|
|
|
if err != nil {
|
|
|
return false, fmt.Errorf("failed to marshal condition value: %v", err)
|
|
|
}
|
|
|
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
|
|
|
// applyOperationsLegacy 原参数覆盖方法
|
|
|
func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
|
|
|
reqMap := make(map[string]interface{})
|
|
|
- err := json.Unmarshal(jsonData, &reqMap)
|
|
|
+ err := common.Unmarshal(jsonData, &reqMap)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
|
|
|
reqMap[key] = value
|
|
|
}
|
|
|
|
|
|
- return json.Marshal(reqMap)
|
|
|
+ return common.Marshal(reqMap)
|
|
|
}
|
|
|
|
|
|
-func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
|
|
|
+func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
|
|
|
+ var contextJSON string
|
|
|
+ if conditionContext != nil && len(conditionContext) > 0 {
|
|
|
+ ctxBytes, err := common.Marshal(conditionContext)
|
|
|
+ if err != nil {
|
|
|
+ return "", fmt.Errorf("failed to marshal condition context: %v", err)
|
|
|
+ }
|
|
|
+ contextJSON = string(ctxBytes)
|
|
|
+ }
|
|
|
+
|
|
|
result := jsonStr
|
|
|
for _, op := range operations {
|
|
|
// 检查条件是否满足
|
|
|
- ok, err := checkConditions(result, op.Conditions, op.Logic)
|
|
|
+ ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
|
|
var currentMap, newMap map[string]interface{}
|
|
|
|
|
|
// 解析当前值
|
|
|
- if err := json.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
|
|
+ if err := common.Unmarshal([]byte(current.Raw), ¤tMap); err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
// 解析新值
|
|
|
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
|
|
case map[string]interface{}:
|
|
|
newMap = v
|
|
|
default:
|
|
|
- jsonBytes, _ := json.Marshal(v)
|
|
|
- if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
|
|
|
+ jsonBytes, _ := common.Marshal(v)
|
|
|
+ if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
}
|
|
|
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
|
|
|
}
|
|
|
return sjson.Set(jsonStr, path, result)
|
|
|
}
|
|
|
+
|
|
|
+// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
|
|
|
+// 目前内置以下字段:
|
|
|
+// - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
|
|
|
+// - upstream_model:始终为通道映射后的上游模型名。
|
|
|
+// - original_model:请求最初指定的模型名。
|
|
|
+func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
|
|
|
+ if info == nil || info.ChannelMeta == nil {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ ctx := make(map[string]interface{})
|
|
|
+ if info.UpstreamModelName != "" {
|
|
|
+ ctx["model"] = info.UpstreamModelName
|
|
|
+ ctx["upstream_model"] = info.UpstreamModelName
|
|
|
+ }
|
|
|
+ if info.OriginModelName != "" {
|
|
|
+ ctx["original_model"] = info.OriginModelName
|
|
|
+ if _, exists := ctx["model"]; !exists {
|
|
|
+ ctx["model"] = info.OriginModelName
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(ctx) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return ctx
|
|
|
+}
|