瀏覽代碼

Merge remote-tracking branch 'origin/alpha' into alpha

t0ng7u 4 月之前
父節點
當前提交
aa1f5c6e4e
共有 4 個文件被更改,包括 156 次插入63 次删除
  1. 1 1
      controller/channel-test.go
  2. 38 33
      controller/relay.go
  3. 115 29
      relay/common/override.go
  4. 2 0
      relay/constant/relay_mode.go

+ 1 - 1
controller/channel-test.go

@@ -445,7 +445,7 @@ func testAllChannels(notify bool) error {
 
 			// disable channel
 			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
-				go processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+				processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 			}
 
 			// enable channel

+ 38 - 33
controller/relay.go

@@ -3,6 +3,7 @@ package controller
 import (
 	"bytes"
 	"fmt"
+	"github.com/bytedance/gopkg/util/gopool"
 	"io"
 	"log"
 	"net/http"
@@ -61,8 +62,8 @@ func geminiRelayHandler(c *gin.Context, info *relaycommon.RelayInfo) *types.NewA
 func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 	requestId := c.GetString(common.RequestIdKey)
-	group := c.GetString("group")
-	originalModel := c.GetString("original_model")
+	group := common.GetContextKeyString(c, constant.ContextKeyUsingGroup)
+	originalModel := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
 
 	var (
 		newAPIError *types.NewAPIError
@@ -172,35 +173,9 @@ func Relay(c *gin.Context, relayFormat types.RelayFormat) {
 
 		if newAPIError == nil {
 			return
-		} else {
-			if constant.ErrorLogEnabled && types.IsRecordErrorLog(newAPIError) {
-				// 保存错误日志到mysql中
-				userId := c.GetInt("id")
-				tokenName := c.GetString("token_name")
-				modelName := c.GetString("original_model")
-				tokenId := c.GetInt("token_id")
-				userGroup := c.GetString("group")
-				channelId := c.GetInt("channel_id")
-				other := make(map[string]interface{})
-				other["error_type"] = newAPIError.GetErrorType()
-				other["error_code"] = newAPIError.GetErrorCode()
-				other["status_code"] = newAPIError.StatusCode
-				other["channel_id"] = channelId
-				other["channel_name"] = c.GetString("channel_name")
-				other["channel_type"] = c.GetInt("channel_type")
-				adminInfo := make(map[string]interface{})
-				adminInfo["use_channel"] = c.GetStringSlice("use_channel")
-				isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
-				if isMultiKey {
-					adminInfo["is_multi_key"] = true
-					adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
-				}
-				other["admin_info"] = adminInfo
-				model.RecordErrorLog(c, userId, channelId, modelName, tokenName, newAPIError.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
-			}
 		}
 
-		go processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
+		processChannelError(c, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(c, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
 
 		if !shouldRetry(c, newAPIError, common.RetryTimes-i) {
 			break
@@ -298,12 +273,42 @@ func shouldRetry(c *gin.Context, openaiErr *types.NewAPIError, retryTimes int) b
 }
 
 func processChannelError(c *gin.Context, channelError types.ChannelError, err *types.NewAPIError) {
-	// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
-	// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
 	logger.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelError.ChannelId, err.StatusCode, err.Error()))
-	if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
-		service.DisableChannel(channelError, err.Error())
+
+	gopool.Go(func() {
+		// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
+		// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
+		if service.ShouldDisableChannel(channelError.ChannelId, err) && channelError.AutoBan {
+			service.DisableChannel(channelError, err.Error())
+		}
+	})
+
+	if constant.ErrorLogEnabled && types.IsRecordErrorLog(err) {
+		// 保存错误日志到mysql中
+		userId := c.GetInt("id")
+		tokenName := c.GetString("token_name")
+		modelName := c.GetString("original_model")
+		tokenId := c.GetInt("token_id")
+		userGroup := c.GetString("group")
+		channelId := c.GetInt("channel_id")
+		other := make(map[string]interface{})
+		other["error_type"] = err.GetErrorType()
+		other["error_code"] = err.GetErrorCode()
+		other["status_code"] = err.StatusCode
+		other["channel_id"] = channelId
+		other["channel_name"] = c.GetString("channel_name")
+		other["channel_type"] = c.GetInt("channel_type")
+		adminInfo := make(map[string]interface{})
+		adminInfo["use_channel"] = c.GetStringSlice("use_channel")
+		isMultiKey := common.GetContextKeyBool(c, constant.ContextKeyChannelIsMultiKey)
+		if isMultiKey {
+			adminInfo["is_multi_key"] = true
+			adminInfo["multi_key_index"] = common.GetContextKeyInt(c, constant.ContextKeyChannelMultiKeyIndex)
+		}
+		other["admin_info"] = adminInfo
+		model.RecordErrorLog(c, userId, channelId, modelName, tokenName, err.MaskSensitiveError(), tokenId, 0, false, userGroup, other)
 	}
+
 }
 
 func RelayMidjourney(c *gin.Context) {

+ 115 - 29
relay/common/override.go

@@ -9,9 +9,11 @@ import (
 )
 
 type ConditionOperation struct {
-	Path  string `json:"path"`  // JSON路径
-	Mode  string `json:"mode"`  // full, prefix, suffix, contains
-	Value string `json:"value"` // 匹配的值
+	Path           string      `json:"path"`             // JSON路径
+	Mode           string      `json:"mode"`             // full, prefix, suffix, contains, gt, gte, lt, lte
+	Value          interface{} `json:"value"`            // 匹配的值
+	Invert         bool        `json:"invert"`           // 反选功能,true表示取反结果
+	PassMissingKey bool        `json:"pass_missing_key"` // 未获取到json key时的行为
 }
 
 type ParamOperation struct {
@@ -34,11 +36,7 @@ func ApplyParamOverride(jsonData []byte, paramOverride map[string]interface{}) (
 	if operations, ok := tryParseOperations(paramOverride); ok {
 		// 使用新方法
 		result, err := applyOperations(string(jsonData), operations)
-		if err != nil {
-			// 新方法失败,回退到旧方法
-			return applyOperationsLegacy(jsonData, paramOverride)
-		}
-		return []byte(result), nil
+		return []byte(result), err
 	}
 
 	// 直接使用旧方法
@@ -95,9 +93,15 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
 									if mode, ok := condMap["mode"].(string); ok {
 										condition.Mode = mode
 									}
-									if value, ok := condMap["value"].(string); ok {
+									if value, ok := condMap["value"]; ok {
 										condition.Value = value
 									}
+									if invert, ok := condMap["invert"].(bool); ok {
+										condition.Invert = invert
+									}
+									if passMissingKey, ok := condMap["pass_missing_key"].(bool); ok {
+										condition.PassMissingKey = passMissingKey
+									}
 									operation.Conditions = append(operation.Conditions, condition)
 								}
 							}
@@ -116,52 +120,131 @@ func tryParseOperations(paramOverride map[string]interface{}) ([]ParamOperation,
 	return nil, false
 }
 
-func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) bool {
+func checkConditions(jsonStr string, conditions []ConditionOperation, logic string) (bool, error) {
 	if len(conditions) == 0 {
-		return true // 没有条件,直接通过
+		return true, nil // 没有条件,直接通过
 	}
 	results := make([]bool, len(conditions))
-
 	for i, condition := range conditions {
-		results[i] = checkSingleCondition(jsonStr, condition)
+		result, err := checkSingleCondition(jsonStr, condition)
+		if err != nil {
+			return false, err
+		}
+		results[i] = result
 	}
+
 	if strings.ToUpper(logic) == "AND" {
 		for _, result := range results {
 			if !result {
-				return false
+				return false, nil
 			}
 		}
-		return true
+		return true, nil
 	} else {
 		for _, result := range results {
 			if result {
-				return true
+				return true, nil
 			}
 		}
-		return false
+		return false, nil
 	}
 }
 
-func checkSingleCondition(jsonStr string, condition ConditionOperation) bool {
+func checkSingleCondition(jsonStr string, condition ConditionOperation) (bool, error) {
 	value := gjson.Get(jsonStr, condition.Path)
 	if !value.Exists() {
-		return false
+		if condition.PassMissingKey {
+			return true, nil
+		}
+		return false, nil
 	}
 
-	valueStr := value.String()
-	targetStr := condition.Value
+	// 利用gjson的类型解析
+	targetBytes, err := json.Marshal(condition.Value)
+	if err != nil {
+		return false, fmt.Errorf("failed to marshal condition value: %v", err)
+	}
+	targetValue := gjson.ParseBytes(targetBytes)
+
+	result, err := compareGjsonValues(value, targetValue, strings.ToLower(condition.Mode))
+	if err != nil {
+		return false, fmt.Errorf("comparison failed for path %s: %v", condition.Path, err)
+	}
+
+	if condition.Invert {
+		result = !result
+	}
+	return result, nil
+}
 
-	switch strings.ToLower(condition.Mode) {
+// compareGjsonValues 直接比较两个gjson.Result,支持所有比较模式
+func compareGjsonValues(jsonValue, targetValue gjson.Result, mode string) (bool, error) {
+	switch mode {
 	case "full":
-		return valueStr == targetStr
+		return compareEqual(jsonValue, targetValue)
 	case "prefix":
-		return strings.HasPrefix(valueStr, targetStr)
+		return strings.HasPrefix(jsonValue.String(), targetValue.String()), nil
 	case "suffix":
-		return strings.HasSuffix(valueStr, targetStr)
+		return strings.HasSuffix(jsonValue.String(), targetValue.String()), nil
 	case "contains":
-		return strings.Contains(valueStr, targetStr)
+		return strings.Contains(jsonValue.String(), targetValue.String()), nil
+	case "gt":
+		return compareNumeric(jsonValue, targetValue, "gt")
+	case "gte":
+		return compareNumeric(jsonValue, targetValue, "gte")
+	case "lt":
+		return compareNumeric(jsonValue, targetValue, "lt")
+	case "lte":
+		return compareNumeric(jsonValue, targetValue, "lte")
 	default:
-		return valueStr == targetStr // 默认精准匹配
+		return false, fmt.Errorf("unsupported comparison mode: %s", mode)
+	}
+}
+
+func compareEqual(jsonValue, targetValue gjson.Result) (bool, error) {
+	// 对布尔值特殊处理
+	if (jsonValue.Type == gjson.True || jsonValue.Type == gjson.False) &&
+		(targetValue.Type == gjson.True || targetValue.Type == gjson.False) {
+		return jsonValue.Bool() == targetValue.Bool(), nil
+	}
+
+	// 如果类型不同,报错
+	if jsonValue.Type != targetValue.Type {
+		return false, fmt.Errorf("compare for different types, got %v and %v", jsonValue.Type, targetValue.Type)
+	}
+
+	switch jsonValue.Type {
+	case gjson.True, gjson.False:
+		return jsonValue.Bool() == targetValue.Bool(), nil
+	case gjson.Number:
+		return jsonValue.Num == targetValue.Num, nil
+	case gjson.String:
+		return jsonValue.String() == targetValue.String(), nil
+	default:
+		return jsonValue.String() == targetValue.String(), nil
+	}
+}
+
+func compareNumeric(jsonValue, targetValue gjson.Result, operator string) (bool, error) {
+	// 只有数字类型才支持数值比较
+	if jsonValue.Type != gjson.Number || targetValue.Type != gjson.Number {
+		return false, fmt.Errorf("numeric comparison requires both values to be numbers, got %v and %v", jsonValue.Type, targetValue.Type)
+	}
+
+	jsonNum := jsonValue.Num
+	targetNum := targetValue.Num
+
+	switch operator {
+	case "gt":
+		return jsonNum > targetNum, nil
+	case "gte":
+		return jsonNum >= targetNum, nil
+	case "lt":
+		return jsonNum < targetNum, nil
+	case "lte":
+		return jsonNum <= targetNum, nil
+	default:
+		return false, fmt.Errorf("unsupported numeric operator: %s", operator)
 	}
 }
 
@@ -184,11 +267,14 @@ func applyOperations(jsonStr string, operations []ParamOperation) (string, error
 	result := jsonStr
 	for _, op := range operations {
 		// 检查条件是否满足
-		if !checkConditions(result, op.Conditions, op.Logic) {
+		ok, err := checkConditions(result, op.Conditions, op.Logic)
+		if err != nil {
+			return "", err
+		}
+		if !ok {
 			continue // 条件不满足,跳过当前操作
 		}
 
-		var err error
 		switch op.Mode {
 		case "delete":
 			result, err = sjson.Delete(result, op.Path)

+ 2 - 0
relay/constant/relay_mode.go

@@ -84,6 +84,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeRealtime
 	} else if strings.HasPrefix(path, "/v1beta/models") || strings.HasPrefix(path, "/v1/models") {
 		relayMode = RelayModeGemini
+	} else if strings.HasPrefix(path, "/mj") {
+		relayMode = Path2RelayModeMidjourney(path)
 	}
 	return relayMode
 }