Forráskód Böngészése

fix: gemini&claude tool call format #795 #766

[email protected] 10 hónapja
szülő
commit
13ab0f8e4f

+ 41 - 39
dto/openai_request.go

@@ -18,50 +18,52 @@ type FormatJsonSchema struct {
 }
 
 type GeneralOpenAIRequest struct {
-	Model               string          `json:"model,omitempty"`
-	Messages            []Message       `json:"messages,omitempty"`
-	Prompt              any             `json:"prompt,omitempty"`
-	Prefix              any             `json:"prefix,omitempty"`
-	Suffix              any             `json:"suffix,omitempty"`
-	Stream              bool            `json:"stream,omitempty"`
-	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
-	MaxTokens           uint            `json:"max_tokens,omitempty"`
-	MaxCompletionTokens uint            `json:"max_completion_tokens,omitempty"`
-	ReasoningEffort     string          `json:"reasoning_effort,omitempty"`
-	Temperature         *float64        `json:"temperature,omitempty"`
-	TopP                float64         `json:"top_p,omitempty"`
-	TopK                int             `json:"top_k,omitempty"`
-	Stop                any             `json:"stop,omitempty"`
-	N                   int             `json:"n,omitempty"`
-	Input               any             `json:"input,omitempty"`
-	Instruction         string          `json:"instruction,omitempty"`
-	Size                string          `json:"size,omitempty"`
-	Functions           any             `json:"functions,omitempty"`
-	FrequencyPenalty    float64         `json:"frequency_penalty,omitempty"`
-	PresencePenalty     float64         `json:"presence_penalty,omitempty"`
-	ResponseFormat      *ResponseFormat `json:"response_format,omitempty"`
-	EncodingFormat      any             `json:"encoding_format,omitempty"`
-	Seed                float64         `json:"seed,omitempty"`
-	Tools               []ToolCall      `json:"tools,omitempty"`
-	ToolChoice          any             `json:"tool_choice,omitempty"`
-	User                string          `json:"user,omitempty"`
-	LogProbs            bool            `json:"logprobs,omitempty"`
-	TopLogProbs         int             `json:"top_logprobs,omitempty"`
-	Dimensions          int             `json:"dimensions,omitempty"`
-	Modalities          any             `json:"modalities,omitempty"`
-	Audio               any             `json:"audio,omitempty"`
-	ExtraBody           any             `json:"extra_body,omitempty"`
+	Model               string            `json:"model,omitempty"`
+	Messages            []Message         `json:"messages,omitempty"`
+	Prompt              any               `json:"prompt,omitempty"`
+	Prefix              any               `json:"prefix,omitempty"`
+	Suffix              any               `json:"suffix,omitempty"`
+	Stream              bool              `json:"stream,omitempty"`
+	StreamOptions       *StreamOptions    `json:"stream_options,omitempty"`
+	MaxTokens           uint              `json:"max_tokens,omitempty"`
+	MaxCompletionTokens uint              `json:"max_completion_tokens,omitempty"`
+	ReasoningEffort     string            `json:"reasoning_effort,omitempty"`
+	Temperature         *float64          `json:"temperature,omitempty"`
+	TopP                float64           `json:"top_p,omitempty"`
+	TopK                int               `json:"top_k,omitempty"`
+	Stop                any               `json:"stop,omitempty"`
+	N                   int               `json:"n,omitempty"`
+	Input               any               `json:"input,omitempty"`
+	Instruction         string            `json:"instruction,omitempty"`
+	Size                string            `json:"size,omitempty"`
+	Functions           any               `json:"functions,omitempty"`
+	FrequencyPenalty    float64           `json:"frequency_penalty,omitempty"`
+	PresencePenalty     float64           `json:"presence_penalty,omitempty"`
+	ResponseFormat      *ResponseFormat   `json:"response_format,omitempty"`
+	EncodingFormat      any               `json:"encoding_format,omitempty"`
+	Seed                float64           `json:"seed,omitempty"`
+	Tools               []ToolCallRequest `json:"tools,omitempty"`
+	ToolChoice          any               `json:"tool_choice,omitempty"`
+	User                string            `json:"user,omitempty"`
+	LogProbs            bool              `json:"logprobs,omitempty"`
+	TopLogProbs         int               `json:"top_logprobs,omitempty"`
+	Dimensions          int               `json:"dimensions,omitempty"`
+	Modalities          any               `json:"modalities,omitempty"`
+	Audio               any               `json:"audio,omitempty"`
+	ExtraBody           any               `json:"extra_body,omitempty"`
 }
 
-type OpenAITools struct {
-	Type     string         `json:"type"`
-	Function OpenAIFunction `json:"function"`
+type ToolCallRequest struct {
+	ID       string          `json:"id,omitempty"`
+	Type     string          `json:"type"`
+	Function FunctionRequest `json:"function"`
 }
 
-type OpenAIFunction struct {
+type FunctionRequest struct {
 	Description string `json:"description,omitempty"`
 	Name        string `json:"name"`
 	Parameters  any    `json:"parameters,omitempty"`
+	Arguments   string `json:"arguments,omitempty"`
 }
 
 type StreamOptions struct {
@@ -137,11 +139,11 @@ func (m *Message) SetPrefix(prefix bool) {
 	m.Prefix = &prefix
 }
 
-func (m *Message) ParseToolCalls() []ToolCall {
+func (m *Message) ParseToolCalls() []ToolCallRequest {
 	if m.ToolCalls == nil {
 		return nil
 	}
-	var toolCalls []ToolCall
+	var toolCalls []ToolCallRequest
 	if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
 		return toolCalls
 	}

+ 12 - 12
dto/openai_response.go

@@ -62,10 +62,10 @@ type ChatCompletionsStreamResponseChoice struct {
 }
 
 type ChatCompletionsStreamResponseChoiceDelta struct {
-	Content          *string    `json:"content,omitempty"`
-	ReasoningContent *string    `json:"reasoning_content,omitempty"`
-	Role             string     `json:"role,omitempty"`
-	ToolCalls        []ToolCall `json:"tool_calls,omitempty"`
+	Content          *string            `json:"content,omitempty"`
+	ReasoningContent *string            `json:"reasoning_content,omitempty"`
+	Role             string             `json:"role,omitempty"`
+	ToolCalls        []ToolCallResponse `json:"tool_calls,omitempty"`
 }
 
 func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) {
@@ -90,24 +90,24 @@ func (c *ChatCompletionsStreamResponseChoiceDelta) SetReasoningContent(s string)
 	c.ReasoningContent = &s
 }
 
-type ToolCall struct {
+type ToolCallResponse struct {
 	// Index is not nil only in chat completion chunk object
-	Index    *int         `json:"index,omitempty"`
-	ID       string       `json:"id,omitempty"`
-	Type     any          `json:"type"`
-	Function FunctionCall `json:"function"`
+	Index    *int             `json:"index,omitempty"`
+	ID       string           `json:"id,omitempty"`
+	Type     any              `json:"type"`
+	Function FunctionResponse `json:"function"`
 }
 
-func (c *ToolCall) SetIndex(i int) {
+func (c *ToolCallResponse) SetIndex(i int) {
 	c.Index = &i
 }
 
-type FunctionCall struct {
+type FunctionResponse struct {
 	Description string `json:"description,omitempty"`
 	Name        string `json:"name,omitempty"`
 	// call function with arguments in JSON format
 	Parameters any    `json:"parameters,omitempty"` // request
-	Arguments  string `json:"arguments"`
+	Arguments  string `json:"arguments"`            // response
 }
 
 type ChatCompletionsStreamResponse struct {

+ 8 - 8
relay/channel/claude/relay-claude.go

@@ -296,7 +296,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 	response.Object = "chat.completion.chunk"
 	response.Model = claudeResponse.Model
 	response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
-	tools := make([]dto.ToolCall, 0)
+	tools := make([]dto.ToolCallResponse, 0)
 	var choice dto.ChatCompletionsStreamResponseChoice
 	if reqMode == RequestModeCompletion {
 		choice.Delta.SetContentString(claudeResponse.Completion)
@@ -315,10 +315,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 			if claudeResponse.ContentBlock != nil {
 				//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
 				if claudeResponse.ContentBlock.Type == "tool_use" {
-					tools = append(tools, dto.ToolCall{
+					tools = append(tools, dto.ToolCallResponse{
 						ID:   claudeResponse.ContentBlock.Id,
 						Type: "function",
-						Function: dto.FunctionCall{
+						Function: dto.FunctionResponse{
 							Name:      claudeResponse.ContentBlock.Name,
 							Arguments: "",
 						},
@@ -333,8 +333,8 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
 				choice.Delta.SetContentString(claudeResponse.Delta.Text)
 				switch claudeResponse.Delta.Type {
 				case "input_json_delta":
-					tools = append(tools, dto.ToolCall{
-						Function: dto.FunctionCall{
+					tools = append(tools, dto.ToolCallResponse{
+						Function: dto.FunctionResponse{
 							Arguments: claudeResponse.Delta.PartialJson,
 						},
 					})
@@ -382,7 +382,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	if len(claudeResponse.Content) > 0 {
 		responseText = claudeResponse.Content[0].Text
 	}
-	tools := make([]dto.ToolCall, 0)
+	tools := make([]dto.ToolCallResponse, 0)
 	thinkingContent := ""
 
 	if reqMode == RequestModeCompletion {
@@ -403,10 +403,10 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 			switch message.Type {
 			case "tool_use":
 				args, _ := json.Marshal(message.Input)
-				tools = append(tools, dto.ToolCall{
+				tools = append(tools, dto.ToolCallResponse{
 					ID:   message.Id,
 					Type: "function", // compatible with other OpenAI derivative applications
-					Function: dto.FunctionCall{
+					Function: dto.FunctionResponse{
 						Name:      message.Name,
 						Arguments: string(args),
 					},

+ 7 - 7
relay/channel/gemini/relay-gemini.go

@@ -43,7 +43,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 
 	// openaiContent.FuncToToolCalls()
 	if textRequest.Tools != nil {
-		functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
+		functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
 		googleSearch := false
 		codeExecution := false
 		for _, tool := range textRequest.Tools {
@@ -338,7 +338,7 @@ func unescapeMapOrSlice(data interface{}) interface{} {
 	return data
 }
 
-func getToolCall(item *GeminiPart) *dto.ToolCall {
+func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
 	var argsBytes []byte
 	var err error
 	if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
@@ -350,10 +350,10 @@ func getToolCall(item *GeminiPart) *dto.ToolCall {
 	if err != nil {
 		return nil
 	}
-	return &dto.ToolCall{
+	return &dto.ToolCallResponse{
 		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
 		Type: "function",
-		Function: dto.FunctionCall{
+		Function: dto.FunctionResponse{
 			Arguments: string(argsBytes),
 			Name:      item.FunctionCall.FunctionName,
 		},
@@ -380,11 +380,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 		}
 		if len(candidate.Content.Parts) > 0 {
 			var texts []string
-			var toolCalls []dto.ToolCall
+			var toolCalls []dto.ToolCallResponse
 			for _, part := range candidate.Content.Parts {
 				if part.FunctionCall != nil {
 					choice.FinishReason = constant.FinishReasonToolCalls
-					if call := getToolCall(&part); call != nil {
+					if call := getResponseToolCall(&part); call != nil {
 						toolCalls = append(toolCalls, *call)
 					}
 				} else {
@@ -457,7 +457,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.C
 		for _, part := range candidate.Content.Parts {
 			if part.FunctionCall != nil {
 				isTools = true
-				if call := getToolCall(&part); call != nil {
+				if call := getResponseToolCall(&part); call != nil {
 					call.SetIndex(len(choice.Delta.ToolCalls))
 					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
 				}

+ 16 - 16
relay/channel/ollama/dto.go

@@ -3,22 +3,22 @@ package ollama
 import "one-api/dto"
 
 type OllamaRequest struct {
-	Model            string             `json:"model,omitempty"`
-	Messages         []dto.Message      `json:"messages,omitempty"`
-	Stream           bool               `json:"stream,omitempty"`
-	Temperature      *float64           `json:"temperature,omitempty"`
-	Seed             float64            `json:"seed,omitempty"`
-	Topp             float64            `json:"top_p,omitempty"`
-	TopK             int                `json:"top_k,omitempty"`
-	Stop             any                `json:"stop,omitempty"`
-	MaxTokens        uint               `json:"max_tokens,omitempty"`
-	Tools            []dto.ToolCall     `json:"tools,omitempty"`
-	ResponseFormat   any                `json:"response_format,omitempty"`
-	FrequencyPenalty float64            `json:"frequency_penalty,omitempty"`
-	PresencePenalty  float64            `json:"presence_penalty,omitempty"`
-	Suffix           any                `json:"suffix,omitempty"`
-	StreamOptions    *dto.StreamOptions `json:"stream_options,omitempty"`
-	Prompt           any                `json:"prompt,omitempty"`
+	Model            string                `json:"model,omitempty"`
+	Messages         []dto.Message         `json:"messages,omitempty"`
+	Stream           bool                  `json:"stream,omitempty"`
+	Temperature      *float64              `json:"temperature,omitempty"`
+	Seed             float64               `json:"seed,omitempty"`
+	Topp             float64               `json:"top_p,omitempty"`
+	TopK             int                   `json:"top_k,omitempty"`
+	Stop             any                   `json:"stop,omitempty"`
+	MaxTokens        uint                  `json:"max_tokens,omitempty"`
+	Tools            []dto.ToolCallRequest `json:"tools,omitempty"`
+	ResponseFormat   any                   `json:"response_format,omitempty"`
+	FrequencyPenalty float64               `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64               `json:"presence_penalty,omitempty"`
+	Suffix           any                   `json:"suffix,omitempty"`
+	StreamOptions    *dto.StreamOptions    `json:"stream_options,omitempty"`
+	Prompt           any                   `json:"prompt,omitempty"`
 }
 
 type Options struct {

+ 1 - 7
service/token_counter.go

@@ -1,7 +1,6 @@
 package service
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"image"
@@ -170,12 +169,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
 	}
 	tkm += msgTokens
 	if request.Tools != nil {
-		toolsData, _ := json.Marshal(request.Tools)
-		var openaiTools []dto.OpenAITools
-		err := json.Unmarshal(toolsData, &openaiTools)
-		if err != nil {
-			return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error()))
-		}
+		openaiTools := request.Tools
 		countStr := ""
 		for _, tool := range openaiTools {
 			countStr = tool.Function.Name