Quellcode durchsuchen

feat: 完善格式抓换,修复gemini渠道和openai渠道在claude code中使用的问题

CaIon vor 4 Monaten
Ursprung
Commit
daa7a13505
6 geänderte Dateien mit 98 neuen und 14 gelöschten Zeilen
  1. 3 0
      common/gin.go
  2. 12 0
      dto/claude.go
  3. 20 0
      dto/openai_response.go
  4. 37 10
      relay/channel/gemini/relay-gemini.go
  5. 16 0
      relay/channel/openai/adaptor.go
  6. 10 4
      service/convert.go

+ 3 - 0
common/gin.go

@@ -31,6 +31,9 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	if err != nil {
 		return err
 	}
+	//if DebugEnabled {
+	//	println("UnmarshalBodyReusable request body:", string(requestBody))
+	//}
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 		err = Unmarshal(requestBody, &v)

+ 12 - 0
dto/claude.go

@@ -199,6 +199,18 @@ type ClaudeRequest struct {
 	Thinking   *Thinking `json:"thinking,omitempty"`
 }
 
+func (c *ClaudeRequest) SearchToolNameByToolCallId(toolCallId string) string {
+	for _, message := range c.Messages {
+		content, _ := message.ParseContent()
+		for _, mediaMessage := range content {
+			if mediaMessage.Id == toolCallId {
+				return mediaMessage.Name
+			}
+		}
+	}
+	return ""
+}
+
 // AddTool 添加工具到请求中
 func (c *ClaudeRequest) AddTool(tool any) {
 	if c.Tools == nil {

+ 20 - 0
dto/openai_response.go

@@ -143,6 +143,13 @@ type ChatCompletionsStreamResponse struct {
 	Usage             *Usage                                `json:"usage"`
 }
 
+func (c *ChatCompletionsStreamResponse) IsFinished() bool {
+	if len(c.Choices) == 0 {
+		return false
+	}
+	return c.Choices[0].FinishReason != nil && *c.Choices[0].FinishReason != ""
+}
+
 func (c *ChatCompletionsStreamResponse) IsToolCall() bool {
 	if len(c.Choices) == 0 {
 		return false
@@ -157,6 +164,19 @@ func (c *ChatCompletionsStreamResponse) GetFirstToolCall() *ToolCallResponse {
 	return nil
 }
 
+func (c *ChatCompletionsStreamResponse) ClearToolCalls() {
+	if !c.IsToolCall() {
+		return
+	}
+	for choiceIdx := range c.Choices {
+		for callIdx := range c.Choices[choiceIdx].Delta.ToolCalls {
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].ID = ""
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Type = nil
+			c.Choices[choiceIdx].Delta.ToolCalls[callIdx].Function.Name = ""
+		}
+	}
+}
+
 func (c *ChatCompletionsStreamResponse) Copy() *ChatCompletionsStreamResponse {
 	choices := make([]ChatCompletionsStreamResponseChoice, len(c.Choices))
 	copy(choices, c.Choices)

+ 37 - 10
relay/channel/gemini/relay-gemini.go

@@ -835,6 +835,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *dto.GeminiChatResponse) (*d
 					call.SetIndex(len(choice.Delta.ToolCalls))
 					choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
 				}
+
 			} else if part.Thought {
 				isThought = true
 				texts = append(texts, part.Text)
@@ -895,6 +896,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 	responseText := strings.Builder{}
 	var usage = &dto.Usage{}
 	var imageCount int
+	finishReason := constant.FinishReasonStop
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse dto.GeminiChatResponse
@@ -936,9 +938,21 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 
 		if info.SendResponseCount == 0 {
 			// send first response
-			err = handleStream(c, info, helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil))
-			if err != nil {
-				common.LogError(c, err.Error())
+			emptyResponse := helper.GenerateStartEmptyResponse(id, createAt, info.UpstreamModelName, nil)
+			if response.IsToolCall() {
+				emptyResponse.Choices[0].Delta.ToolCalls = make([]dto.ToolCallResponse, 1)
+				emptyResponse.Choices[0].Delta.ToolCalls[0] = *response.GetFirstToolCall()
+				emptyResponse.Choices[0].Delta.ToolCalls[0].Function.Arguments = ""
+				finishReason = constant.FinishReasonToolCalls
+				err = handleStream(c, info, emptyResponse)
+				if err != nil {
+					common.LogError(c, err.Error())
+				}
+
+				response.ClearToolCalls()
+				if response.IsFinished() {
+					response.Choices[0].FinishReason = nil
+				}
 			}
 		}
 
@@ -947,7 +961,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
 			common.LogError(c, err.Error())
 		}
 		if isStop {
-			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop))
+			_ = handleStream(c, info, helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, finishReason))
 		}
 		return true
 	})
@@ -1026,13 +1040,26 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
 	}
 
 	fullTextResponse.Usage = usage
-	jsonResponse, err := json.Marshal(fullTextResponse)
-	if err != nil {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+
+	switch info.RelayFormat {
+	case relaycommon.RelayFormatOpenAI:
+		responseBody, err = common.Marshal(fullTextResponse)
+		if err != nil {
+			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		}
+	case relaycommon.RelayFormatClaude:
+		claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
+		claudeRespStr, err := common.Marshal(claudeResp)
+		if err != nil {
+			return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		}
+		responseBody = claudeRespStr
+	case relaycommon.RelayFormatGemini:
+		break
 	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	c.Writer.Write(jsonResponse)
+
+	common.IOCopyBytesGracefully(c, resp, responseBody)
+
 	return &usage, nil
 }
 

+ 16 - 0
relay/channel/openai/adaptor.go

@@ -63,10 +63,26 @@ func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayIn
 	//if !strings.Contains(request.Model, "claude") {
 	//	return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
 	//}
+	//if common.DebugEnabled {
+	//	bodyBytes := []byte(common.GetJsonString(request))
+	//	err := os.WriteFile(fmt.Sprintf("claude_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
+	//	if err != nil {
+	//		println(fmt.Sprintf("failed to save request body to file: %v", err))
+	//	}
+	//}
 	aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
 	if err != nil {
 		return nil, err
 	}
+	//if common.DebugEnabled {
+	//	println(fmt.Sprintf("convert claude to openai request result: %s", common.GetJsonString(aiRequest)))
+	//	// Save request body to file for debugging
+	//	bodyBytes := []byte(common.GetJsonString(aiRequest))
+	//	err = os.WriteFile(fmt.Sprintf("claude_to_openai_request_%s.txt", c.GetString(common.RequestIdKey)), bodyBytes, 0644)
+	//	if err != nil {
+	//		println(fmt.Sprintf("failed to save request body to file: %v", err))
+	//	}
+	//}
 	if info.SupportStreamOptions && info.IsStream {
 		aiRequest.StreamOptions = &dto.StreamOptions{
 			IncludeUsage: true,

+ 10 - 4
service/convert.go

@@ -153,9 +153,13 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest, info *relaycommon.Re
 					toolCalls = append(toolCalls, toolCall)
 				case "tool_result":
 					// Add tool result as a separate message
+					toolName := mediaMsg.Name
+					if toolName == "" {
+						toolName = claudeRequest.SearchToolNameByToolCallId(mediaMsg.ToolUseId)
+					}
 					oaiToolMessage := dto.Message{
 						Role:       "tool",
-						Name:       &mediaMsg.Name,
+						Name:       &toolName,
 						ToolCallId: mediaMsg.ToolUseId,
 					}
 					//oaiToolMessage.SetStringContent(*mediaMsg.GetMediaContent().Text)
@@ -218,12 +222,14 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 		//	Type: "ping",
 		//})
 		if openAIResponse.IsToolCall() {
+			info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeTools
 			resp := &dto.ClaudeResponse{
 				Type: "content_block_start",
 				ContentBlock: &dto.ClaudeMediaMessage{
-					Id:   openAIResponse.GetFirstToolCall().ID,
-					Type: "tool_use",
-					Name: openAIResponse.GetFirstToolCall().Function.Name,
+					Id:    openAIResponse.GetFirstToolCall().ID,
+					Type:  "tool_use",
+					Name:  openAIResponse.GetFirstToolCall().Function.Name,
+					Input: map[string]interface{}{},
 				},
 			}
 			resp.SetIndex(0)