Răsfoiți Sursa

fix: mutil func call in gemini

Yan 1 an în urmă
părinte
comite
a4795737fe
3 a modificat fișierele cu 113 adăugiri și 75 ștergeri
  1. 1 0
      .gitignore
  2. 3 1
      common/str.go
  3. 109 74
      relay/channel/gemini/relay-gemini.go

+ 1 - 0
.gitignore

@@ -8,3 +8,4 @@ build
 logs
 web/dist
 .env
+one-api

+ 3 - 1
common/str.go

@@ -35,7 +35,9 @@ func StrToMap(str string) map[string]interface{} {
 	m := make(map[string]interface{})
 	err := json.Unmarshal([]byte(str), &m)
 	if err != nil {
-		return nil
+		return map[string]interface{}{
+			"result": str,
+		}
 	}
 	return m
 }

+ 109 - 74
relay/channel/gemini/relay-gemini.go

@@ -95,7 +95,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 			geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
 		}
 	}
-
+	tool_call_ids := make(map[string]string)
 	//shouldAddDummyModelMessage := false
 	for _, message := range textRequest.Messages {
 
@@ -108,6 +108,27 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 				},
 			}
 			continue
+		} else if message.Role == "tool" {
+			if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role != "user" {
+				geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+					Role: "user",
+				})
+			}
+			var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
+			name := ""
+			if message.Name != nil {
+				name = *message.Name
+			} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
+				name = val
+			}
+			functionResp := &FunctionResponse{
+				Name:     name,
+				Response: common.StrToMap(message.StringContent()),
+			}
+			*parts = append(*parts, GeminiPart{
+				FunctionResponse: functionResp,
+			})
+			continue
 		}
 		var parts []GeminiPart
 		content := GeminiChatContent{
@@ -125,62 +146,49 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 					},
 				}
 				parts = append(parts, toolCall)
+				tool_call_ids[call.ID] = call.Function.Name
 			}
 		}
 		if !isToolCall {
-			if message.Role == "tool" {
-				content.Role = "user"
-				name := ""
-				if message.Name != nil {
-					name = *message.Name
-				}
-				functionResp := &FunctionResponse{
-					Name:     name,
-					Response: common.StrToMap(message.StringContent()),
-				}
-				parts = append(parts, GeminiPart{
-					FunctionResponse: functionResp,
-				})
-			} else {
-				openaiContent := message.ParseContent()
-				imageNum := 0
-				for _, part := range openaiContent {
-					if part.Type == dto.ContentTypeText {
+			openaiContent := message.ParseContent()
+			imageNum := 0
+			for _, part := range openaiContent {
+				if part.Type == dto.ContentTypeText {
+					parts = append(parts, GeminiPart{
+						Text: part.Text,
+					})
+				} else if part.Type == dto.ContentTypeImageURL {
+					imageNum += 1
+
+					if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
+						return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
+					}
+					// 判断是否是url
+					if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
+						// 是url,获取图片的类型和base64编码的数据
+						mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
 						parts = append(parts, GeminiPart{
-							Text: part.Text,
+							InlineData: &GeminiInlineData{
+								MimeType: mimeType,
+								Data:     data,
+							},
 						})
-					} else if part.Type == dto.ContentTypeImageURL {
-						imageNum += 1
-
-						if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
-							return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
-						}
-						// 判断是否是url
-						if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
-							// 是url,获取图片的类型和base64编码的数据
-							mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
-							parts = append(parts, GeminiPart{
-								InlineData: &GeminiInlineData{
-									MimeType: mimeType,
-									Data:     data,
-								},
-							})
-						} else {
-							_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
-							if err != nil {
-								return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
-							}
-							parts = append(parts, GeminiPart{
-								InlineData: &GeminiInlineData{
-									MimeType: "image/" + format,
-									Data:     base64String,
-								},
-							})
+					} else {
+						_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
+						if err != nil {
+							return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
 						}
+						parts = append(parts, GeminiPart{
+							InlineData: &GeminiInlineData{
+								MimeType: "image/" + format,
+								Data:     base64String,
+							},
+						})
 					}
 				}
 			}
 		}
+
 		content.Parts = parts
 
 		// there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -242,19 +250,13 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }
 
-func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
-	var toolCalls []dto.ToolCall
-
-	item := candidate.Content.Parts[0]
-	if item.FunctionCall == nil {
-		return toolCalls
-	}
+func getToolCall(item *GeminiPart) *dto.ToolCall {
 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
 	if err != nil {
-		//common.SysError("getToolCalls failed: " + err.Error())
-		return toolCalls
+		//common.SysError("getToolCall failed: " + err.Error())
+		return nil
 	}
-	toolCall := dto.ToolCall{
+	return &dto.ToolCall{
 		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
 		Type: "function",
 		Function: dto.FunctionCall{
@@ -262,10 +264,32 @@ func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
 			Name:      item.FunctionCall.FunctionName,
 		},
 	}
-	toolCalls = append(toolCalls, toolCall)
-	return toolCalls
 }
 
+// func getToolCalls(candidate *GeminiChatCandidate, index int) []dto.ToolCall {
+// 	var toolCalls []dto.ToolCall
+
+// 	item := candidate.Content.Parts[index]
+// 	if item.FunctionCall == nil {
+// 		return toolCalls
+// 	}
+// 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
+// 	if err != nil {
+// 		//common.SysError("getToolCalls failed: " + err.Error())
+// 		return toolCalls
+// 	}
+// 	toolCall := dto.ToolCall{
+// 		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
+// 		Type: "function",
+// 		Function: dto.FunctionCall{
+// 			Arguments: string(argsBytes),
+// 			Name:      item.FunctionCall.FunctionName,
+// 		},
+// 	}
+// 	toolCalls = append(toolCalls, toolCall)
+// 	return toolCalls
+// }
+
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -275,6 +299,8 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 	}
 	content, _ := json.Marshal("")
 	for i, candidate := range response.Candidates {
+		// jsonData, _ := json.MarshalIndent(candidate, "", "  ")
+		// common.SysLog(fmt.Sprintf("candidate: %v", string(jsonData)))
 		choice := dto.OpenAITextResponseChoice{
 			Index: i,
 			Message: dto.Message{
@@ -284,16 +310,20 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 			FinishReason: constant.FinishReasonStop,
 		}
 		if len(candidate.Content.Parts) > 0 {
-			if candidate.Content.Parts[0].FunctionCall != nil {
-				choice.FinishReason = constant.FinishReasonToolCalls
-				choice.Message.SetToolCalls(getToolCalls(&candidate))
-			} else {
-				var texts []string
-				for _, part := range candidate.Content.Parts {
+			var texts []string
+			var tool_calls []dto.ToolCall
+			for _, part := range candidate.Content.Parts {
+				if part.FunctionCall != nil {
+					choice.FinishReason = constant.FinishReasonToolCalls
+					if call := getToolCall(&part); call != nil {
+						tool_calls = append(tool_calls, *call)
+					}
+				} else {
 					texts = append(texts, part.Text)
 				}
-				choice.Message.SetStringContent(strings.Join(texts, "\n"))
 			}
+			choice.Message.SetStringContent(strings.Join(texts, "\n"))
+			choice.Message.SetToolCalls(tool_calls)
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
@@ -304,18 +334,23 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	var choice dto.ChatCompletionsStreamResponseChoice
 	//choice.Delta.SetContentString(geminiResponse.GetResponseText())
 	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
-		respFirstParts := geminiResponse.Candidates[0].Content.Parts
-		if respFirstParts[0].FunctionCall != nil {
-			// function response
-			choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
-		} else {
-			// text response
-			var texts []string
-			for _, part := range respFirstParts {
+		var texts []string
+		var tool_calls []dto.ToolCall
+		for _, part := range geminiResponse.Candidates[0].Content.Parts {
+			if part.FunctionCall != nil {
+				if call := getToolCall(&part); call != nil {
+					tool_calls = append(tool_calls, *call)
+				}
+			} else {
 				texts = append(texts, part.Text)
 			}
+		}
+		if len(texts) > 0 {
 			choice.Delta.SetContentString(strings.Join(texts, "\n"))
 		}
+		if len(tool_calls) > 0 {
+			choice.Delta.ToolCalls = tool_calls
+		}
 	}
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"