Parcourir la source

fix: claude to openai convert logic (#416)

* fix: claude to openai convert logic

* fix: ci lint
zijiren il y a 1 mois
Parent
commit
eda72c1ff4
1 fichiers modifiés avec 97 ajouts et 102 suppressions
  1. 97 102
      core/relay/adaptor/openai/claude.go

+ 97 - 102
core/relay/adaptor/openai/claude.go

@@ -119,127 +119,122 @@ func convertClaudeMessagesToOpenAI(
 
 	// Convert regular messages
 	for _, msg := range claudeRequest.Messages {
-		// Check if this is a user message with tool results - handle specially
-		if msg.Role == "user" {
-			content, ok := msg.Content.([]any)
-
-			hasToolResults := false
-			if ok {
-				rawBytes, _ := sonic.Marshal(content)
-
-				var contentArray []relaymodel.ClaudeContent
-
-				_ = sonic.Unmarshal(rawBytes, &contentArray)
-
-				// First check if there are any tool_result blocks
-				var regularContent []relaymodel.MessageContent
-				for _, content := range contentArray {
-					switch content.Type {
-					case "tool_result":
-						hasToolResults = true
-						// Create a separate tool message for each tool_result
-						toolMsg := relaymodel.Message{
-							Role:       "tool",
-							Content:    content.Content,
-							ToolCallID: content.ToolUseID,
-						}
-						messages = append(messages, toolMsg)
-					case "text":
-						// Collect non-tool_result content
-						regularContent = append(regularContent, relaymodel.MessageContent{
-							Type: relaymodel.ContentTypeText,
-							Text: content.Text,
-						})
-					}
-				}
+		openAIMsg := relaymodel.Message{
+			Role: msg.Role,
+		}
 
-				// If there were tool results and also regular content, add the regular content as a user message
-				if hasToolResults {
-					if len(regularContent) > 0 {
-						messages = append(messages, relaymodel.Message{
-							Role:    "user",
-							Content: regularContent,
-						})
-					}
+		result := convertClaudeContent(msg.Content)
+		messages = append(messages, result.Messages...)
+		openAIMsg.ToolCalls = result.ToolCalls
 
-					continue // Skip the normal message processing
-				}
-			}
+		openAIMsg.Content = result.Content
+		if openAIMsg.Content != nil {
+			messages = append(messages, openAIMsg)
 		}
+	}
 
-		// Regular message processing
-		openAIMsg := relaymodel.Message{
-			Role: msg.Role,
-		}
+	return messages
+}
 
-		switch content := msg.Content.(type) {
-		case string:
-			openAIMsg.Content = content
-		case []any:
-			rawBytes, _ := sonic.Marshal(content)
+type convertClaudeContentResult struct {
+	Content   any
+	ToolCalls []relaymodel.ToolCall
+	Messages  []relaymodel.Message
+}
 
-			var contentArray []relaymodel.ClaudeContent
+func convertClaudeContent(content any) convertClaudeContentResult {
+	result := convertClaudeContentResult{}
+	switch content := content.(type) {
+	case string:
+		result.Content = content
+	case []any:
+		rawBytes, _ := sonic.Marshal(content)
 
-			_ = sonic.Unmarshal(rawBytes, &contentArray)
+		var contentArray []relaymodel.ClaudeContent
 
-			var parts []relaymodel.MessageContent
-			for _, content := range contentArray {
-				switch content.Type {
-				case "text":
-					parts = append(parts, relaymodel.MessageContent{
-						Type: relaymodel.ContentTypeText,
-						Text: content.Text,
-					})
-				case "thinking":
-					parts = append(parts, relaymodel.MessageContent{
-						Type: relaymodel.ContentTypeText,
-						Text: content.Thinking,
-					})
-				case "image":
-					if content.Source != nil {
-						imageURL := relaymodel.ImageURL{}
-						switch content.Source.Type {
-						case "url":
-							imageURL.URL = content.Source.URL
-						case "base64":
-							imageURL.URL = fmt.Sprintf("data:%s;base64,%s",
-								content.Source.MediaType, content.Source.Data)
-						}
+		_ = sonic.Unmarshal(rawBytes, &contentArray)
 
-						parts = append(parts, relaymodel.MessageContent{
-							Type:     relaymodel.ContentTypeImageURL,
-							ImageURL: &imageURL,
-						})
-					}
-				case "tool_use":
-					// Handle tool calls
-					if openAIMsg.ToolCalls == nil {
-						openAIMsg.ToolCalls = []relaymodel.ToolCall{}
+		var parts []relaymodel.MessageContent
+		for _, content := range contentArray {
+			switch content.Type {
+			case "text":
+				text := strings.TrimSpace(content.Text)
+				if text == "" {
+					continue
+				}
+
+				parts = append(parts, relaymodel.MessageContent{
+					Type: relaymodel.ContentTypeText,
+					Text: text,
+				})
+			case "thinking":
+				text := strings.TrimSpace(content.Thinking)
+				if text == "" {
+					continue
+				}
+
+				parts = append(parts, relaymodel.MessageContent{
+					Type: relaymodel.ContentTypeText,
+					Text: text,
+				})
+			case "image":
+				if content.Source != nil {
+					imageURL := relaymodel.ImageURL{}
+					switch content.Source.Type {
+					case "url":
+						imageURL.URL = content.Source.URL
+					case "base64":
+						imageURL.URL = fmt.Sprintf("data:%s;base64,%s",
+							content.Source.MediaType, content.Source.Data)
 					}
 
-					args, _ := sonic.MarshalString(content.Input)
-					openAIMsg.ToolCalls = append(openAIMsg.ToolCalls, relaymodel.ToolCall{
-						ID:   content.ID,
-						Type: "function",
-						Function: relaymodel.Function{
-							Name:      content.Name,
-							Arguments: args,
-						},
+					parts = append(parts, relaymodel.MessageContent{
+						Type:     relaymodel.ContentTypeImageURL,
+						ImageURL: &imageURL,
 					})
-				default:
-					continue
 				}
-			}
+			case "tool_use":
+				// Handle tool calls
+				args, _ := sonic.MarshalString(content.Input)
+				result.ToolCalls = append(result.ToolCalls, relaymodel.ToolCall{
+					ID:   content.ID,
+					Type: "function",
+					Function: relaymodel.Function{
+						Name:      content.Name,
+						Arguments: args,
+					},
+				})
+			case "tool_result":
+				// Create a separate tool message for each tool_result
+				var newContent any
+				switch v := content.Content.(type) {
+				case string:
+					newContent = v
+				case []any:
+					result := convertClaudeContent(v)
+					newContent = result.Content
+				}
 
-			if len(parts) > 0 {
-				openAIMsg.Content = parts
+				toolMsg := relaymodel.Message{
+					Role:       "tool",
+					Content:    newContent,
+					ToolCallID: content.ToolUseID,
+				}
+
+				result.Messages = append(result.Messages, toolMsg)
+
+				continue
+			default:
+				continue
 			}
 		}
 
-		messages = append(messages, openAIMsg)
+		if len(parts) > 0 {
+			result.Content = parts
+		}
 	}
 
-	return messages
+	return result
 }
 
 // convertClaudeToolsToOpenAI converts Claude tools to OpenAI format