
refactor: Update OpenAI request and message handling

- Changed the type of ToolCalls in the Message struct from `any` to `json.RawMessage` for better type safety and clarity.
- Introduced ParseToolCalls and SetToolCalls methods on Message so callers no longer hand-roll JSON round-trips for tool calls (see the usage sketch after this list).
- Renamed MediaMessage to MediaContent and updated ParseContent to return the new type, clarifying the content-parsing structures.
- Refactored the Claude and Gemini relay functions to use the new ToolCalls helpers, removing the ad-hoc type assertions and per-call marshal/unmarshal steps.
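
A minimal, self-contained Go sketch of the new round-trip, for illustration only (it is not part of this change set): the Message field, SetToolCalls, and ParseToolCalls mirror the diff below, while the ToolCall and ToolCallFunction shapes are trimmed stand-ins inferred from how the relay code consumes them.

```go
// Illustrative round-trip of the new tool-call helpers. The ToolCall and
// ToolCallFunction types here are simplified stand-ins, not the real dto types.
package main

import (
	"encoding/json"
	"fmt"
)

type ToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}

type ToolCall struct {
	ID       string           `json:"id"`
	Type     string           `json:"type"`
	Function ToolCallFunction `json:"function"`
}

type Message struct {
	Role      string          `json:"role"`
	Content   json.RawMessage `json:"content,omitempty"`
	ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
}

// SetToolCalls stores any tool-call slice as raw JSON (pointer receiver so
// the assignment is visible to the caller).
func (m *Message) SetToolCalls(toolCalls any) {
	raw, _ := json.Marshal(toolCalls)
	m.ToolCalls = raw
}

// ParseToolCalls decodes the raw JSON back into typed tool calls; it returns
// nil when the field is absent or cannot be decoded.
func (m Message) ParseToolCalls() []ToolCall {
	if m.ToolCalls == nil {
		return nil
	}
	var calls []ToolCall
	_ = json.Unmarshal(m.ToolCalls, &calls)
	return calls
}

func main() {
	var msg Message
	msg.Role = "assistant"
	// A response converter (Claude/Gemini relay) writes tool calls once...
	msg.SetToolCalls([]ToolCall{{
		ID:   "call_1",
		Type: "function",
		Function: ToolCallFunction{Name: "get_weather", Arguments: `{"city":"Paris"}`},
	}})
	// ...and a request converter reads the typed view back out.
	for _, call := range msg.ParseToolCalls() {
		fmt.Println(call.Function.Name, call.Function.Arguments)
	}
}
```

Keeping tool_calls as json.RawMessage defers decoding until a relay actually needs the typed view, so malformed or provider-specific payloads fail where they are parsed rather than at request binding.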
CalciumIon committed 1 year ago (commit 0c326556aa)

dto/openai_request.go (+26 -10)

@@ -22,7 +22,7 @@ type GeneralOpenAIRequest struct {
 	StreamOptions       *StreamOptions  `json:"stream_options,omitempty"`
 	MaxTokens           uint            `json:"max_tokens,omitempty"`
 	MaxCompletionTokens uint            `json:"max_completion_tokens,omitempty"`
-	ReasoningEffort     string         `json:"reasoning_effort,omitempty"`
+	ReasoningEffort     string          `json:"reasoning_effort,omitempty"`
 	Temperature         float64         `json:"temperature,omitempty"`
 	TopP                float64         `json:"top_p,omitempty"`
 	TopK                int             `json:"top_k,omitempty"`
@@ -89,11 +89,27 @@ type Message struct {
 	Role       string          `json:"role"`
 	Content    json.RawMessage `json:"content"`
 	Name       *string         `json:"name,omitempty"`
-	ToolCalls  any             `json:"tool_calls,omitempty"`
+	ToolCalls  json.RawMessage `json:"tool_calls,omitempty"`
 	ToolCallId string          `json:"tool_call_id,omitempty"`
 }
 
-type MediaMessage struct {
+func (m Message) ParseToolCalls() []ToolCall {
+	if m.ToolCalls == nil {
+		return nil
+	}
+	var toolCalls []ToolCall
+	if err := json.Unmarshal(m.ToolCalls, &toolCalls); err == nil {
+		return toolCalls
+	}
+	return toolCalls
+}
+
+func (m *Message) SetToolCalls(toolCalls any) {
+	toolCallsJson, _ := json.Marshal(toolCalls)
+	m.ToolCalls = toolCallsJson
+}
+
+type MediaContent struct {
 	Type       string `json:"type"`
 	Text       string `json:"text"`
 	ImageUrl   any    `json:"image_url,omitempty"`
@@ -137,11 +153,11 @@ func (m Message) IsStringContent() bool {
 	return false
 }
 
-func (m Message) ParseContent() []MediaMessage {
-	var contentList []MediaMessage
+func (m Message) ParseContent() []MediaContent {
+	var contentList []MediaContent
 	var stringContent string
 	if err := json.Unmarshal(m.Content, &stringContent); err == nil {
-		contentList = append(contentList, MediaMessage{
+		contentList = append(contentList, MediaContent{
 			Type: ContentTypeText,
 			Text: stringContent,
 		})
@@ -157,7 +173,7 @@ func (m Message) ParseContent() []MediaMessage {
 			switch contentMap["type"] {
 			case ContentTypeText:
 				if subStr, ok := contentMap["text"].(string); ok {
-					contentList = append(contentList, MediaMessage{
+					contentList = append(contentList, MediaContent{
 						Type: ContentTypeText,
 						Text: subStr,
 					})
@@ -170,7 +186,7 @@ func (m Message) ParseContent() []MediaMessage {
 					} else {
 						subObj["detail"] = "high"
 					}
-					contentList = append(contentList, MediaMessage{
+					contentList = append(contentList, MediaContent{
 						Type: ContentTypeImageURL,
 						ImageUrl: MessageImageUrl{
 							Url:    subObj["url"].(string),
@@ -178,7 +194,7 @@ func (m Message) ParseContent() []MediaMessage {
 						},
 					})
 				} else if url, ok := contentMap["image_url"].(string); ok {
-					contentList = append(contentList, MediaMessage{
+					contentList = append(contentList, MediaContent{
 						Type: ContentTypeImageURL,
 						ImageUrl: MessageImageUrl{
 							Url:    url,
@@ -188,7 +204,7 @@ func (m Message) ParseContent() []MediaMessage {
 				}
 			case ContentTypeInputAudio:
 				if subObj, ok := contentMap["input_audio"].(map[string]any); ok {
-					contentList = append(contentList, MediaMessage{
+					contentList = append(contentList, MediaContent{
 						Type: ContentTypeInputAudio,
 						InputAudio: MessageInputAudio{
 							Data:   subObj["data"].(string),

relay/channel/claude/relay-claude.go (+2 -9)

@@ -240,14 +240,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 				}
 				if message.ToolCalls != nil {
-					for _, tc := range message.ToolCalls.([]interface{}) {
-						toolCallJSON, _ := json.Marshal(tc)
-						var toolCall dto.ToolCall
-						err := json.Unmarshal(toolCallJSON, &toolCall)
-						if err != nil {
-							common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
-							continue
-						}
+					for _, toolCall := range message.ParseToolCalls() {
 						inputObj := make(map[string]any)
 						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
 							common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
@@ -393,7 +386,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
 	}
 	choice.SetStringContent(responseText)
 	if len(tools) > 0 {
-		choice.Message.ToolCalls = tools
+		choice.Message.SetToolCalls(tools)
 	}
 	fullTextResponse.Model = claudeResponse.Model
 	choices = append(choices, choice)

relay/channel/gemini/relay-gemini.go (+50 -55)

@@ -108,50 +108,63 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 				},
 			}
 			continue
+		} else if message.Role == "tool" {
+			message.Role = "model"
 		}
+
+		var parts []GeminiPart
 		content := GeminiChatContent{
 			Role: message.Role,
-			//Parts: []GeminiPart{
-			//	{
-			//		Text: message.StringContent(),
-			//	},
-			//},
 		}
-		openaiContent := message.ParseContent()
-		var parts []GeminiPart
-		imageNum := 0
-		for _, part := range openaiContent {
-			if part.Type == dto.ContentTypeText {
-				parts = append(parts, GeminiPart{
-					Text: part.Text,
-				})
-			} else if part.Type == dto.ContentTypeImageURL {
-				imageNum += 1
-
-				if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
-					return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+		isToolCall := false
+		if message.ToolCalls != nil {
+			isToolCall = true
+			for _, call := range message.ParseToolCalls() {
+				toolCall := GeminiPart{
+					FunctionCall: &FunctionCall{
+						FunctionName: call.Function.Name,
+						Arguments:    call.Function.Parameters,
+					},
 				}
-				// check whether this is a URL
-				if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
-					// it is a URL: fetch the image's MIME type and base64-encoded data
-					mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+				parts = append(parts, toolCall)
+			}
+		}
+		if !isToolCall {
+			openaiContent := message.ParseContent()
+			imageNum := 0
+			for _, part := range openaiContent {
+				if part.Type == dto.ContentTypeText {
 					parts = append(parts, GeminiPart{
-						InlineData: &GeminiInlineData{
-							MimeType: mimeType,
-							Data:     data,
-						},
+						Text: part.Text,
 					})
-				} else {
-					_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
-					if err != nil {
-						return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
+				} else if part.Type == dto.ContentTypeImageURL {
+					imageNum += 1
+
+					if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
+						return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+					}
+					// check whether this is a URL
+					if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
+						// it is a URL: fetch the image's MIME type and base64-encoded data
+						mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
+						parts = append(parts, GeminiPart{
+							InlineData: &GeminiInlineData{
+								MimeType: mimeType,
+								Data:     data,
+							},
+						})
+					} else {
+						_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
+						if err != nil {
+							return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
+						}
+						parts = append(parts, GeminiPart{
+							InlineData: &GeminiInlineData{
+								MimeType: "image/" + format,
+								Data:     base64String,
+							},
+						})
 					}
-					parts = append(parts, GeminiPart{
-						InlineData: &GeminiInlineData{
-							MimeType: "image/" + format,
-							Data:     base64String,
-						},
-					})
 				}
 			}
 		}
@@ -161,25 +174,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		if content.Role == "assistant" {
 			content.Role = "model"
 		}
-		// Converting system prompt to prompt from user for the same reason
-		//if content.Role == "system" {
-		//	content.Role = "user"
-		//	shouldAddDummyModelMessage = true
-		//}
 		geminiRequest.Contents = append(geminiRequest.Contents, content)
-		//
-		//// If a system message is the last message, we need to add a dummy model message to make gemini happy
-		//if shouldAddDummyModelMessage {
-		//	geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
-		//		Role: "model",
-		//		Parts: []GeminiPart{
-		//			{
-		//				Text: "Okay",
-		//			},
-		//		},
-		//	})
-		//	shouldAddDummyModelMessage = false
-		//}
 	}
 	return &geminiRequest, nil
 }
@@ -278,7 +273,7 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 		if len(candidate.Content.Parts) > 0 {
 			if candidate.Content.Parts[0].FunctionCall != nil {
 				choice.FinishReason = constant.FinishReasonToolCalls
-				choice.Message.ToolCalls = getToolCalls(&candidate)
+				choice.Message.SetToolCalls(getToolCalls(&candidate))
 			} else {
 				var texts []string
 				for _, part := range candidate.Content.Parts {