Procházet zdrojové kódy

Merge pull request #1384 from QuantumNous/RequestOpenAI2ClaudeMessage

feat: 改进 RequestOpenAI2ClaudeMessage 和添加 claude web search 计费
Calcium-Ion před 5 měsíci
rodič
revize
6b3f1ab0e4

+ 83 - 4
dto/claude.go

@@ -159,6 +159,27 @@ type InputSchema struct {
 	Required   any    `json:"required,omitempty"`
 }
 
+type ClaudeWebSearchTool struct { // Claude web search tool entry; RequestOpenAI2ClaudeMessage appends one to the request tools when web_search_options is present
+	Type         string                       `json:"type"` // the converter sets "web_search_20250305"
+	Name         string                       `json:"name"` // the converter sets "web_search"
+	MaxUses      int                          `json:"max_uses,omitempty"` // chosen from search_context_size (low/medium/high); omitted from JSON when 0
+	UserLocation *ClaudeWebSearchUserLocation `json:"user_location,omitempty"` // optional; omitted from JSON when nil
+}
+
+type ClaudeWebSearchUserLocation struct { // user_location payload of ClaudeWebSearchTool
+	Type     string `json:"type"` // the converter always sets "approximate"
+	Timezone string `json:"timezone,omitempty"` // copied from the OpenAI user_location "approximate" object when a non-empty string
+	Country  string `json:"country,omitempty"` // same source as Timezone; omitted from JSON when empty
+	Region   string `json:"region,omitempty"` // same source as Timezone; omitted from JSON when empty
+	City     string `json:"city,omitempty"` // same source as Timezone; omitted from JSON when empty
+}
+
+type ClaudeToolChoice struct { // Claude-side tool_choice value produced by mapToolChoice
+	Type                   string `json:"type"` // mapToolChoice emits "auto", "any", "none" or "tool"
+	Name                   string `json:"name,omitempty"` // tool name; mapToolChoice fills it only for type "tool"
+	DisableParallelToolUse bool   `json:"disable_parallel_tool_use,omitempty"` // inverse of the OpenAI parallel_tool_calls flag; omitted from JSON when false
+}
+
 type ClaudeRequest struct {
 	Model             string          `json:"model"`
 	Prompt            string          `json:"prompt,omitempty"`
@@ -177,6 +198,59 @@ type ClaudeRequest struct {
 	Thinking   *Thinking `json:"thinking,omitempty"`
 }
 
+// AddTool appends tool to the request's Tools field, which is treated as a []any.
+func (c *ClaudeRequest) AddTool(tool any) {
+	if c.Tools == nil {
+		c.Tools = make([]any, 0)
+	}
+
+	switch tools := c.Tools.(type) {
+	case []any:
+		c.Tools = append(tools, tool)
+	default:
+		// Tools holds something other than []any: the old value is dropped and Tools restarts as a []any containing only tool.
+		c.Tools = []any{tool}
+	}
+}
+
+// GetTools returns Tools as a []any, or nil when Tools is unset or holds any other type.
+func (c *ClaudeRequest) GetTools() []any {
+	if c.Tools == nil {
+		return nil
+	}
+
+	switch tools := c.Tools.(type) {
+	case []any:
+		return tools
+	default:
+		return nil
+	}
+}
+
+// ProcessTools splits tools by dynamic type into regular tools and web search tools; both value and pointer forms are accepted.
+func ProcessTools(tools []any) ([]*Tool, []*ClaudeWebSearchTool) {
+	var normalTools []*Tool
+	var webSearchTools []*ClaudeWebSearchTool
+
+	for _, tool := range tools {
+		switch t := tool.(type) {
+		case *Tool:
+			normalTools = append(normalTools, t)
+		case *ClaudeWebSearchTool:
+			webSearchTools = append(webSearchTools, t)
+		case Tool:
+			normalTools = append(normalTools, &t)
+		case ClaudeWebSearchTool:
+			webSearchTools = append(webSearchTools, &t)
+		default:
+			// Unknown element type: skip it, so it appears in neither result slice.
+			continue
+		}
+	}
+
+	return normalTools, webSearchTools
+}
+
 type Thinking struct {
 	Type         string `json:"type"`
 	BudgetTokens *int   `json:"budget_tokens,omitempty"`
@@ -251,8 +325,13 @@ func (c *ClaudeResponse) GetIndex() int {
 }
 
 type ClaudeUsage struct {
-	InputTokens              int `json:"input_tokens"`
-	CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
-	CacheReadInputTokens     int `json:"cache_read_input_tokens"`
-	OutputTokens             int `json:"output_tokens"`
+	InputTokens              int                  `json:"input_tokens"`
+	CacheCreationInputTokens int                  `json:"cache_creation_input_tokens"`
+	CacheReadInputTokens     int                  `json:"cache_read_input_tokens"`
+	OutputTokens             int                  `json:"output_tokens"`
+	ServerToolUse            *ClaudeServerToolUse `json:"server_tool_use"`
+}
+
+type ClaudeServerToolUse struct { // value of the "server_tool_use" field in ClaudeUsage
+	WebSearchRequests int `json:"web_search_requests"` // when > 0, HandleClaudeResponseData stores it in the gin context as "claude_web_search_requests"
 }

+ 141 - 2
relay/channel/claude/relay-claude.go

@@ -18,6 +18,12 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
+const ( // max_uses values for the Claude web search tool, selected by the OpenAI search_context_size
+	WebSearchMaxUsesLow    = 1 // used for search_context_size "low"
+	WebSearchMaxUsesMedium = 5 // used for search_context_size "medium"
+	WebSearchMaxUsesHigh   = 10 // used for search_context_size "high"
+)
+
 func stopReasonClaude2OpenAI(reason string) string {
 	switch reason {
 	case "stop_sequence":
@@ -65,7 +71,7 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.Cla
 }
 
 func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
-	claudeTools := make([]dto.Tool, 0, len(textRequest.Tools))
+	claudeTools := make([]any, 0, len(textRequest.Tools))
 
 	for _, tool := range textRequest.Tools {
 		if params, ok := tool.Function.Parameters.(map[string]any); ok {
@@ -85,10 +91,62 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 				}
 				claudeTool.InputSchema[s] = a
 			}
-			claudeTools = append(claudeTools, claudeTool)
+			claudeTools = append(claudeTools, &claudeTool)
 		}
 	}
 
+	// Web search tool
+	// https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
+	if textRequest.WebSearchOptions != nil {
+		webSearchTool := dto.ClaudeWebSearchTool{
+			Type: "web_search_20250305",
+			Name: "web_search",
+		}
+
+		// 处理 user_location
+		if textRequest.WebSearchOptions.UserLocation != nil {
+			anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
+				Type: "approximate", // 固定为 "approximate"
+			}
+
+			// 解析 UserLocation JSON
+			var userLocationMap map[string]interface{}
+			if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
+				// 检查是否有 approximate 字段
+				if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
+					if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
+						anthropicUserLocation.Timezone = timezone
+					}
+					if country, ok := approximateData["country"].(string); ok && country != "" {
+						anthropicUserLocation.Country = country
+					}
+					if region, ok := approximateData["region"].(string); ok && region != "" {
+						anthropicUserLocation.Region = region
+					}
+					if city, ok := approximateData["city"].(string); ok && city != "" {
+						anthropicUserLocation.City = city
+					}
+				}
+			}
+
+			webSearchTool.UserLocation = anthropicUserLocation
+		}
+
+		// 处理 search_context_size 转换为 max_uses
+		if textRequest.WebSearchOptions.SearchContextSize != "" {
+			switch textRequest.WebSearchOptions.SearchContextSize {
+			case "low":
+				webSearchTool.MaxUses = WebSearchMaxUsesLow
+			case "medium":
+				webSearchTool.MaxUses = WebSearchMaxUsesMedium
+			case "high":
+				webSearchTool.MaxUses = WebSearchMaxUsesHigh
+			}
+		}
+
+		claudeTools = append(claudeTools, &webSearchTool)
+	}
+
 	claudeRequest := dto.ClaudeRequest{
 		Model:         textRequest.Model,
 		MaxTokens:     textRequest.MaxTokens,
@@ -100,6 +158,14 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 		Tools:         claudeTools,
 	}
 
+	// 处理 tool_choice 和 parallel_tool_calls
+	if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
+		claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
+		if claudeToolChoice != nil {
+			claudeRequest.ToolChoice = claudeToolChoice
+		}
+	}
+
 	if claudeRequest.MaxTokens == 0 {
 		claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
 	}
@@ -124,6 +190,27 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 		claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
 	}
 
+	if textRequest.ReasoningEffort != "" {
+		switch textRequest.ReasoningEffort {
+		case "low":
+			claudeRequest.Thinking = &dto.Thinking{
+				Type:         "enabled",
+				BudgetTokens: common.GetPointer[int](1280),
+			}
+		case "medium":
+			claudeRequest.Thinking = &dto.Thinking{
+				Type:         "enabled",
+				BudgetTokens: common.GetPointer[int](2048),
+			}
+		case "high":
+			claudeRequest.Thinking = &dto.Thinking{
+				Type:         "enabled",
+				BudgetTokens: common.GetPointer[int](4096),
+			}
+		}
+	}
+
+	// 指定了 reasoning 参数,覆盖 budgetTokens
 	if textRequest.Reasoning != nil {
 		var reasoning openrouter.RequestReasoning
 		if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
@@ -645,6 +732,10 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 		responseData = data
 	}
 
+	if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
+		c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
+	}
+
 	common.IOCopyBytesGracefully(c, nil, responseData)
 	return nil
 }
@@ -672,3 +763,51 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 	}
 	return nil, claudeInfo.Usage
 }
+
+func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
+	var claudeToolChoice *dto.ClaudeToolChoice
+
+	// String form of tool_choice: "auto" stays "auto", "required" becomes "any", "none" stays "none"; any other string leaves the result nil.
+	if toolChoiceStr, ok := toolChoice.(string); ok {
+		switch toolChoiceStr {
+		case "auto":
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "auto",
+			}
+		case "required":
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "any",
+			}
+		case "none":
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "none",
+			}
+		}
+	} else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
+		// Object form of tool_choice: a string at function.name selects that tool by name (type "tool").
+		if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
+			if toolName, ok := function["name"].(string); ok {
+				claudeToolChoice = &dto.ClaudeToolChoice{
+					Type: "tool",
+					Name: toolName,
+				}
+			}
+		}
+	}
+
+	// parallel_tool_calls, when supplied, is applied on top of whatever was mapped above.
+	if parallelToolCalls != nil {
+		if claudeToolChoice == nil {
+			// No tool_choice was mapped but parallel_tool_calls is set: fall back to a default "auto" choice.
+			claudeToolChoice = &dto.ClaudeToolChoice{
+				Type: "auto",
+			}
+		}
+
+		// disable_parallel_tool_use is the inverse of parallel_tool_calls (true -> false, false -> true).
+		// NOTE(review): the flag is set for every mapped type, including "none" - confirm the upstream API accepts that combination.
+		claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
+	}
+
+	return claudeToolChoice
+}

+ 16 - 0
relay/relay-text.go

@@ -379,6 +379,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 	// openai web search 工具计费
 	var dWebSearchQuota decimal.Decimal
 	var webSearchPrice float64
+	// response api 格式工具计费
 	if relayInfo.ResponsesUsageInfo != nil {
 		if webSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview]; exists && webSearchTool.CallCount > 0 {
 			// 计算 web search 调用的配额 (配额 = 价格 * 调用次数 / 1000 * 分组倍率)
@@ -401,6 +402,17 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 		extraContent += fmt.Sprintf("Web Search 调用 1 次,上下文大小 %s,调用花费 %s",
 			searchContextSize, dWebSearchQuota.String())
 	}
+	// claude web search tool 计费
+	var dClaudeWebSearchQuota decimal.Decimal
+	var claudeWebSearchPrice float64
+	claudeWebSearchCallCount := ctx.GetInt("claude_web_search_requests")
+	if claudeWebSearchCallCount > 0 {
+		claudeWebSearchPrice = operation_setting.GetClaudeWebSearchPricePerThousand()
+		dClaudeWebSearchQuota = decimal.NewFromFloat(claudeWebSearchPrice).
+			Div(decimal.NewFromInt(1000)).Mul(dGroupRatio).Mul(dQuotaPerUnit).Mul(decimal.NewFromInt(int64(claudeWebSearchCallCount)))
+		extraContent += fmt.Sprintf("Claude Web Search 调用 %d 次,调用花费 %s",
+			claudeWebSearchCallCount, dClaudeWebSearchQuota.String())
+	}
 	// file search tool 计费
 	var dFileSearchQuota decimal.Decimal
 	var fileSearchPrice float64
@@ -524,6 +536,10 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
 			other["web_search_call_count"] = 1
 			other["web_search_price"] = webSearchPrice
 		}
+	} else if !dClaudeWebSearchQuota.IsZero() {
+		other["web_search"] = true
+		other["web_search_call_count"] = claudeWebSearchCallCount
+		other["web_search_price"] = claudeWebSearchPrice
 	}
 	if !dFileSearchQuota.IsZero() && relayInfo.ResponsesUsageInfo != nil {
 		if fileSearchTool, exists := relayInfo.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolFileSearch]; exists {

+ 9 - 0
setting/operation_setting/tools.go

@@ -23,6 +23,15 @@ const (
 	Gemini20FlashInputAudioPrice            = 0.70
 )
 
+const (
+	// ClaudeWebSearchPrice is the price per thousand Claude web search requests, as returned by GetClaudeWebSearchPricePerThousand (currency unit not shown in this file section).
+	ClaudeWebSearchPrice = 10.00
+)
+
+func GetClaudeWebSearchPricePerThousand() float64 { // price per thousand Claude web search requests
+	return ClaudeWebSearchPrice // a single flat constant; the function takes no model or context-size input
+}
+
 func GetWebSearchPricePerThousand(modelName string, contextSize string) float64 {
 	// 确定模型类型
 	// https://platform.openai.com/docs/pricing Web search 价格按模型类型和 search context size 收费