瀏覽代碼

Merge pull request #1603 from nekohy/feats-the-flexable-params-override

feats: the flexable params override and compatible format
Calcium-Ion 4 月之前
父節點
當前提交
73a7b33864

+ 4 - 0
go.mod

@@ -31,6 +31,8 @@ require (
 	github.com/shopspring/decimal v1.4.0
 	github.com/stripe/stripe-go/v81 v81.4.0
 	github.com/thanhpk/randstr v1.0.6
+	github.com/tidwall/gjson v1.18.0
+	github.com/tidwall/sjson v1.2.5
 	github.com/tiktoken-go/tokenizer v0.6.2
 	golang.org/x/crypto v0.35.0
 	golang.org/x/image v0.23.0
@@ -82,6 +84,8 @@ require (
 	github.com/modern-go/reflect2 v1.0.2 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.1 // indirect
 	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.0 // indirect
 	github.com/tklauser/go-sysconf v0.3.12 // indirect
 	github.com/tklauser/numcpus v0.6.1 // indirect
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect

+ 9 - 0
go.sum

@@ -204,6 +204,15 @@ github.com/stripe/stripe-go/v81 v81.4.0 h1:AuD9XzdAvl193qUCSaLocf8H+nRopOouXhxqJ
 github.com/stripe/stripe-go/v81 v81.4.0/go.mod h1:C/F4jlmnGNacvYtBp/LUHCvVUJEZffFQCobkzwY1WOo=
 github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
 github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
+github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
 github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g=
 github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
 github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=

+ 1 - 6
relay/claude_handler.go

@@ -83,12 +83,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			reqMap := make(map[string]interface{})
-			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range info.ParamOverride {
-				reqMap[key] = value
-			}
-			jsonData, err = common.Marshal(reqMap)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 310 - 0
relay/common/override.go

@@ -0,0 +1,310 @@
+package common
+
+import (
+	"encoding/json"
+	"fmt"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+	"strings"
+)
+
+type ConditionOperation struct {
+	Path  string `json:"path"`  // JSON路径
+	Mode  string `json:"mode"`  // full, prefix, suffix, contains
+	Value string `json:"value"` // 匹配的值
+}
+
+type ParamOperation struct {
+	Path       string               `json:"path"`
+	Mode       string               `json:"mode"` // delete, set, move, prepend, append
+	Value      interface{}          `json:"value"`
+	KeepOrigin bool                 `json:"keep_origin"`
+	From       string               `json:"from,omitempty"`
+	To         string               `json:"to,omitempty"`
+	Conditions []ConditionOperation `json:"conditions,omitempty"` // 条件列表
+	Logic      string               `json:"logic,omitempty"`      // AND, OR (默认OR)
+}
+
+func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
+	if len(paramOverride) == 0 {
+		return jsonData, nil
+	}
+
+	// 尝试断言为操作格式
+	if operations, ok := tryParseOperations(paramOverride); ok {
+		// 使用新方法
+		result, err := applyOperations(string(jsonData), operations)
+		if err != nil {
+			// 新方法失败,回退到旧方法
+			return applyOperationsLegacy(jsonData, paramOverride)
+		}
+		return []byte(result), nil
+	}
+
+	// 直接使用旧方法
+	return applyOperationsLegacy(jsonData, paramOverride)
+}
+
+func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation, bool) {
+	// 检查是否包含 "operations" 字段
+	if opsValue, exists := paramOverride["operations"]; exists {
+		if opsSlice, ok := opsValue.([]interface{}); ok {
+			var operations []ParamOperation
+			for _, op := range opsSlice {
+				if opMap, ok := op.(map[string]interface{}); ok {
+					operation := ParamOperation{}
+
+					// 断言必要字段
+					if path, ok := opMap["path"].(string); ok {
+						operation.Path = path
+					}
+					if mode, ok := opMap["mode"].(string); ok {
+						operation.Mode = mode
+					} else {
+						return nil, false // mode 是必需的
+					}
+
+					// 可选字段
+					if value, exists := opMap["value"]; exists {
+						operation.Value = value
+					}
+					if keepOrigin, ok := opMap["keep_origin"].(bool); ok {
+						operation.KeepOrigin = keepOrigin
+					}
+					if from, ok := opMap["from"].(string); ok {
+						operation.From = from
+					}
+					if to, ok := opMap["to"].(string); ok {
+						operation.To = to
+					}
+					if logic, ok := opMap["logic"].(string); ok {
+						operation.Logic = logic
+					} else {
+						operation.Logic = "OR" // 默认为OR
+					}
+
+					// 解析条件
+					if conditions, exists := opMap["conditions"]; exists {
+						if condSlice, ok := conditions.([]interface{}); ok {
+							for _, cond := range condSlice {
+								if condMap, ok := cond.(map[string]interface{}); ok {
+									condition := ConditionOperation{}
+									if path, ok := condMap["path"].(string); ok {
+										condition.Path = path
+									}
+									if mode, ok := condMap["mode"].(string); ok {
+										condition.Mode = mode
+									}
+									if value, ok := condMap["value"].(string); ok {
+										condition.Value = value
+									}
+									operation.Conditions = append(operation.Conditions, condition)
+								}
+							}
+						}
+					}
+
+					operations = append(operations, operation)
+				} else {
+					return nil, false
+				}
+			}
+			return operations, true
+		}
+	}
+
+	return nil, false
+}
+
+func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) bool {
+	if len(conditions) == 0 {
+		return true // 没有条件,直接通过
+	}
+	results := make([]bool, len(conditions))
+
+	for i, condition := range conditions {
+		results[i] = checkSingleCondition(jsonStr, condition)
+	}
+	if strings.ToUpper(logic) == "AND" {
+		for _, result := range results {
+			if !result {
+				return false
+			}
+		}
+		return true
+	} else {
+		for _, result := range results {
+			if result {
+				return true
+			}
+		}
+		return false
+	}
+}
+
+func checkSingleCondition(jsonStr string, condition ConditionOperation) bool {
+	value := gjson.Get(jsonStr, condition.Path)
+	if !value.Exists() {
+		return false
+	}
+
+	valueStr := value.String()
+	targetStr := condition.Value
+
+	switch strings.ToLower(condition.Mode) {
+	case "full":
+		return valueStr == targetStr
+	case "prefix":
+		return strings.HasPrefix(valueStr, targetStr)
+	case "suffix":
+		return strings.HasSuffix(valueStr, targetStr)
+	case "contains":
+		return strings.Contains(valueStr, targetStr)
+	default:
+		return valueStr == targetStr // 默认精准匹配
+	}
+}
+
+// applyOperationsLegacy 原参数覆盖方法
+func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
+	reqMap := make(map[string]interface{})
+	err := json.Unmarshal(jsonData, &reqMap)
+	if err != nil {
+		return nil, err
+	}
+
+	for key, value := range paramOverride {
+		reqMap[key] = value
+	}
+
+	return json.Marshal(reqMap)
+}
+
+func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
+	result := jsonStr
+	for _, op := range operations {
+		// 检查条件是否满足
+		if !checkConditions(result, op.Conditions, op.Logic) {
+			continue // 条件不满足,跳过当前操作
+		}
+
+		var err error
+		switch op.Mode {
+		case "delete":
+			result, err = sjson.Delete(result, op.Path)
+		case "set":
+			if op.KeepOrigin && gjson.Get(result, op.Path).Exists() {
+				continue
+			}
+			result, err = sjson.Set(result, op.Path, op.Value)
+		case "move":
+			result, err = moveValue(result, op.From, op.To)
+		case "prepend":
+			result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, true)
+		case "append":
+			result, err = modifyValue(result, op.Path, op.Value, op.KeepOrigin, false)
+		default:
+			return "", fmt.Errorf("unknown operation: %s", op.Mode)
+		}
+		if err != nil {
+			return "", fmt.Errorf("operation %s failed: %v", op.Mode, err)
+		}
+	}
+	return result, nil
+}
+
+func moveValue(jsonStr, fromPath, toPath string) (string, error) {
+	sourceValue := gjson.Get(jsonStr, fromPath)
+	if !sourceValue.Exists() {
+		return jsonStr, fmt.Errorf("source path does not exist: %s", fromPath)
+	}
+	result, err := sjson.Set(jsonStr, toPath, sourceValue.Value())
+	if err != nil {
+		return "", err
+	}
+	return sjson.Delete(result, fromPath)
+}
+
+func modifyValue(jsonStr, path string, value interface{}, keepOrigin, isPrepend bool) (string, error) {
+	current := gjson.Get(jsonStr, path)
+	switch {
+	case current.IsArray():
+		return modifyArray(jsonStr, path, value, isPrepend)
+	case current.Type == gjson.String:
+		return modifyString(jsonStr, path, value, isPrepend)
+	case current.Type == gjson.JSON:
+		return mergeObjects(jsonStr, path, value, keepOrigin)
+	}
+	return jsonStr, fmt.Errorf("operation not supported for type: %v", current.Type)
+}
+
+func modifyArray(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
+	current := gjson.Get(jsonStr, path)
+	var newArray []interface{}
+	// 添加新值
+	addValue := func() {
+		if arr, ok := value.([]interface{}); ok {
+			newArray = append(newArray, arr...)
+		} else {
+			newArray = append(newArray, value)
+		}
+	}
+	// 添加原值
+	addOriginal := func() {
+		current.ForEach(func(_, val gjson.Result) bool {
+			newArray = append(newArray, val.Value())
+			return true
+		})
+	}
+	if isPrepend {
+		addValue()
+		addOriginal()
+	} else {
+		addOriginal()
+		addValue()
+	}
+	return sjson.Set(jsonStr, path, newArray)
+}
+
+func modifyString(jsonStr, path string, value interface{}, isPrepend bool) (string, error) {
+	current := gjson.Get(jsonStr, path)
+	valueStr := fmt.Sprintf("%v", value)
+	var newStr string
+	if isPrepend {
+		newStr = valueStr + current.String()
+	} else {
+		newStr = current.String() + valueStr
+	}
+	return sjson.Set(jsonStr, path, newStr)
+}
+
+func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (string, error) {
+	current := gjson.Get(jsonStr, path)
+	var currentMap, newMap map[string]interface{}
+
+	// 解析当前值
+	if err := json.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
+		return "", err
+	}
+	// 解析新值
+	switch v := value.(type) {
+	case map[string]interface{}:
+		newMap = v
+	default:
+		jsonBytes, _ := json.Marshal(v)
+		if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
+			return "", err
+		}
+	}
+	// 合并
+	result := make(map[string]interface{})
+	for k, v := range currentMap {
+		result[k] = v
+	}
+	for k, v := range newMap {
+		if !keepOrigin || result[k] == nil {
+			result[k] = v
+		}
+	}
+	return sjson.Set(jsonStr, path, result)
+}

+ 1 - 6
relay/gemini_handler.go

@@ -123,12 +123,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			reqMap := make(map[string]interface{})
-			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range info.ParamOverride {
-				reqMap[key] = value
-			}
-			jsonData, err = common.Marshal(reqMap)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 1 - 6
relay/image_handler.go

@@ -63,12 +63,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 
 			// apply param override
 			if len(info.ParamOverride) > 0 {
-				reqMap := make(map[string]interface{})
-				_ = common.Unmarshal(jsonData, &reqMap)
-				for key, value := range info.ParamOverride {
-					reqMap[key] = value
-				}
-				jsonData, err = common.Marshal(reqMap)
+				jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
 				if err != nil {
 					return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 				}

+ 1 - 6
relay/relay-text.go

@@ -133,12 +133,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			reqMap := make(map[string]interface{})
-			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range info.ParamOverride {
-				reqMap[key] = value
-			}
-			jsonData, err = common.Marshal(reqMap)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 1 - 6
relay/rerank_handler.go

@@ -62,12 +62,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			reqMap := make(map[string]interface{})
-			_ = common.Unmarshal(jsonData, &reqMap)
-			for key, value := range info.ParamOverride {
-				reqMap[key] = value
-			}
-			jsonData, err = common.Marshal(reqMap)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 1 - 9
relay/responses_handler.go

@@ -53,18 +53,10 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 		}
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			reqMap := make(map[string]interface{})
-			err = common.Unmarshal(jsonData, &reqMap)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}
-			for key, value := range info.ParamOverride {
-				reqMap[key] = value
-			}
-			jsonData, err = common.Marshal(reqMap)
-			if err != nil {
-				return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
-			}
 		}
 
 		if common.DebugEnabled {

+ 45 - 17
web/src/components/table/channels/modals/EditChannelModal.jsx

@@ -1653,25 +1653,53 @@ const EditChannelModal = (props) => {
                   />
 
                   <Form.TextArea
-                    field='param_override'
-                    label={t('参数覆盖')}
-                    placeholder={
-                      t('此项可选,用于覆盖请求参数。不支持覆盖 stream 参数。为一个 JSON 字符串,例如:') +
-                      '\n{\n  "temperature": 0\n}'
-                    }
-                    autosize
-                    onChange={(value) => handleInputChange('param_override', value)}
-                    extraText={
-                      <Text
-                        className="!text-semi-color-primary cursor-pointer"
-                        onClick={() => handleInputChange('param_override', JSON.stringify({ temperature: 0 }, null, 2))}
-                      >
-                        {t('填入模板')}
-                      </Text>
-                    }
-                    showClear
+                      field='param_override'
+                      label={t('参数覆盖')}
+                      placeholder={
+                          t('此项可选,用于覆盖请求参数。不支持覆盖 stream 参数') +
+                          '\n' + t('旧格式(直接覆盖):') +
+                          '\n{\n  "temperature": 0,\n  "max_tokens": 1000\n}' +
+                          '\n\n' + t('新格式(支持条件判断与json自定义):') +
+                          '\n{\n  "operations": [\n    {\n      "path": "temperature",\n      "mode": "set",\n      "value": 0.7,\n      "conditions": [\n        {\n          "path": "model",\n          "mode": "prefix",\n          "value": "gpt"\n        }\n      ]\n    }\n  ]\n}'
+                      }
+                      autosize
+                      onChange={(value) => handleInputChange('param_override', value)}
+                      extraText={
+                        <div className="flex gap-2 flex-wrap">
+                          <Text
+                              className="!text-semi-color-primary cursor-pointer"
+                              onClick={() => handleInputChange('param_override', JSON.stringify({ temperature: 0 }, null, 2))}
+                          >
+                            {t('旧格式模板')}
+                          </Text>
+                          <Text
+                              className="!text-semi-color-primary cursor-pointer"
+                              onClick={() => handleInputChange('param_override', JSON.stringify({
+                                operations: [
+                                  {
+                                    path: "temperature",
+                                    mode: "set",
+                                    value: 0.7,
+                                    conditions: [
+                                      {
+                                        path: "model",
+                                        mode: "prefix",
+                                        value: "gpt"
+                                      }
+                                    ],
+                                    logic: "AND"
+                                  }
+                                ]
+                              }, null, 2))}
+                          >
+                            {t('新格式模板')}
+                          </Text>
+                        </div>
+                      }
+                      showClear
                   />
 
+
                   <JSONEditor
                     key={`status_code_mapping-${isEdit ? channelId : 'new'}`}
                     field='status_code_mapping'