Преглед изворни кода

Merge pull request #2276 from seefs001/feature/internal_params

feat: embedding param override && internal params
Calcium-Ion пре 1 месец
родитељ
комит
dbfcb441f7

+ 1 - 1
relay/claude_handler.go

@@ -123,7 +123,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 54 - 14
relay/common/override.go

@@ -1,12 +1,12 @@
 package common
 
 import (
-	"encoding/json"
 	"fmt"
 	"regexp"
 	"strconv"
 	"strings"
 
+	"github.com/QuantumNous/new-api/common"
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/sjson"
 )
@@ -30,7 +30,7 @@ type ParamOperation struct {
 	Logic      string               `json:"logic,omitempty"`      // AND, OR (默认OR)
 }
 
-func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
+func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}, conditionContext map[string]interface{}) ([]byte, error) {
 	if len(paramOverride) == 0 {
 		return jsonData, nil
 	}
@@ -38,7 +38,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
 	// 尝试断言为操作格式
 	if operations, ok := tryParseOperations(paramOverride); ok {
 		// 使用新方法
-		result, err := applyOperations(string(jsonData), operations)
+		result, err := applyOperations(string(jsonData), operations, conditionContext)
 		return []byte(result), err
 	}
 
@@ -123,13 +123,13 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
 	return nil, false
 }
 
-func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
+func checkConditions(jsonStr, contextJSON string, conditions []ConditionOperation, logic string) (bool, error) {
 	if len(conditions) == 0 {
 		return true, nil // 没有条件,直接通过
 	}
 	results := make([]bool, len(conditions))
 	for i, condition := range conditions {
-		result, err := checkSingleCondition(jsonStr, condition)
+		result, err := checkSingleCondition(jsonStr, contextJSON, condition)
 		if err != nil {
 			return false, err
 		}
@@ -153,10 +153,13 @@ func checkConditions(jsonStr string, conditions []ConditionOperation, logic stri
 	}
 }
 
-func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
+func checkSingleCondition(jsonStr, contextJSON string, condition ConditionOperation) (bool, error) {
 	// 处理负数索引
 	path := processNegativeIndex(jsonStr, condition.Path)
 	value := gjson.Get(jsonStr, path)
+	if !value.Exists() && contextJSON != "" {
+		value = gjson.Get(contextJSON, condition.Path)
+	}
 	if !value.Exists() {
 		if condition.PassMissingKey {
 			return true, nil
@@ -165,7 +168,7 @@ func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, e
 	}
 
 	// 利用gjson的类型解析
-	targetBytes, err := json.Marshal(condition.Value)
+	targetBytes, err := common.Marshal(condition.Value)
 	if err != nil {
 		return false, fmt.Errorf("failed to marshal condition value: %v", err)
 	}
@@ -292,7 +295,7 @@ func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool,
 // applyOperationsLegacy 原参数覆盖方法
 func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}) ([]byte, error) {
 	reqMap := make(map[string]interface{})
-	err := json.Unmarshal(jsonData, &reqMap)
+	err := common.Unmarshal(jsonData, &reqMap)
 	if err != nil {
 		return nil, err
 	}
@@ -301,14 +304,23 @@ func applyOperationsLegacy(jsonData []byte, paramOverride map[string]interface{}
 		reqMap[key] = value
 	}
 
-	return json.Marshal(reqMap)
+	return common.Marshal(reqMap)
 }
 
-func applyOperations(jsonStr string, operations []ParamOperation) (string, error) {
+func applyOperations(jsonStr string, operations []ParamOperation, conditionContext map[string]interface{}) (string, error) {
+	var contextJSON string
+	if conditionContext != nil && len(conditionContext) > 0 {
+		ctxBytes, err := common.Marshal(conditionContext)
+		if err != nil {
+			return "", fmt.Errorf("failed to marshal condition context: %v", err)
+		}
+		contextJSON = string(ctxBytes)
+	}
+
 	result := jsonStr
 	for _, op := range operations {
 		// 检查条件是否满足
-		ok, err := checkConditions(result, op.Conditions, op.Logic)
+		ok, err := checkConditions(result, contextJSON, op.Conditions, op.Logic)
 		if err != nil {
 			return "", err
 		}
@@ -414,7 +426,7 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
 	var currentMap, newMap map[string]interface{}
 
 	// 解析当前值
-	if err := json.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
+	if err := common.Unmarshal([]byte(current.Raw), &currentMap); err != nil {
 		return "", err
 	}
 	// 解析新值
@@ -422,8 +434,8 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
 	case map[string]interface{}:
 		newMap = v
 	default:
-		jsonBytes, _ := json.Marshal(v)
-		if err := json.Unmarshal(jsonBytes, &newMap); err != nil {
+		jsonBytes, _ := common.Marshal(v)
+		if err := common.Unmarshal(jsonBytes, &newMap); err != nil {
 			return "", err
 		}
 	}
@@ -439,3 +451,31 @@ func mergeObjects(jsonStr, path string, value interface{}, keepOrigin bool) (str
 	}
 	return sjson.Set(jsonStr, path, result)
 }
+
+// BuildParamOverrideContext 提供 ApplyParamOverride 可用的上下文信息。
+// 目前内置以下字段:
+//   - model:优先使用上游模型名(UpstreamModelName),若不存在则回落到原始模型名(OriginModelName)。
+//   - upstream_model:始终为通道映射后的上游模型名。
+//   - original_model:请求最初指定的模型名。
+func BuildParamOverrideContext(info *RelayInfo) map[string]interface{} {
+	if info == nil || info.ChannelMeta == nil {
+		return nil
+	}
+
+	ctx := make(map[string]interface{})
+	if info.UpstreamModelName != "" {
+		ctx["model"] = info.UpstreamModelName
+		ctx["upstream_model"] = info.UpstreamModelName
+	}
+	if info.OriginModelName != "" {
+		ctx["original_model"] = info.OriginModelName
+		if _, exists := ctx["model"]; !exists {
+			ctx["model"] = info.OriginModelName
+		}
+	}
+
+	if len(ctx) == 0 {
+		return nil
+	}
+	return ctx
+}

+ 1 - 1
relay/compatible_handler.go

@@ -144,7 +144,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 8 - 0
relay/embedding_handler.go

@@ -49,6 +49,14 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 	if err != nil {
 		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
 	}
+
+	if len(info.ParamOverride) > 0 {
+		jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+		}
+	}
+
 	logger.LogDebug(c, fmt.Sprintf("converted embedding request body: %s", string(jsonData)))
 	requestBody := bytes.NewBuffer(jsonData)
 	statusCodeMappingStr := c.GetString("status_code_mapping")

+ 1 - 1
relay/gemini_handler.go

@@ -156,7 +156,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 1 - 1
relay/image_handler.go

@@ -69,7 +69,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
 
 			// apply param override
 			if len(info.ParamOverride) > 0 {
-				jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+				jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 				if err != nil {
 					return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 				}

+ 1 - 1
relay/rerank_handler.go

@@ -60,7 +60,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}

+ 1 - 1
relay/responses_handler.go

@@ -66,7 +66,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
 
 		// apply param override
 		if len(info.ParamOverride) > 0 {
-			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride)
+			jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
 			if err != nil {
 				return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
 			}