Sfoglia il codice sorgente

feat: 完善openai转claude支持

CaIon 8 mesi fa
parent
commit
4b3e30e669

+ 6 - 1
dto/claude.go

@@ -7,7 +7,7 @@ type ClaudeMetadata struct {
 }
 }
 
 
 type ClaudeMediaMessage struct {
 type ClaudeMediaMessage struct {
-	Type        string               `json:"type"`
+	Type        string               `json:"type,omitempty"`
 	Text        *string              `json:"text,omitempty"`
 	Text        *string              `json:"text,omitempty"`
 	Model       string               `json:"model,omitempty"`
 	Model       string               `json:"model,omitempty"`
 	Source      *ClaudeMessageSource `json:"source,omitempty"`
 	Source      *ClaudeMessageSource `json:"source,omitempty"`
@@ -50,6 +50,11 @@ func (c *ClaudeMediaMessage) GetStringContent() string {
 	return ""
 	return ""
 }
 }
 
 
+func (c *ClaudeMediaMessage) GetJsonRowString() string {
+	jsonContent, _ := json.Marshal(c)
+	return string(jsonContent)
+}
+
 func (c *ClaudeMediaMessage) SetContent(content any) {
 func (c *ClaudeMediaMessage) SetContent(content any) {
 	jsonContent, _ := json.Marshal(content)
 	jsonContent, _ := json.Marshal(content)
 	c.Content = jsonContent
 	c.Content = jsonContent

+ 6 - 0
dto/openai_request.go

@@ -214,6 +214,12 @@ func (m *Message) StringContent() string {
 	return stringContent
 	return stringContent
 }
 }
 
 
+func (m *Message) SetNullContent() {
+	m.Content = nil
+	m.parsedStringContent = nil
+	m.parsedContent = nil
+}
+
 func (m *Message) SetStringContent(content string) {
 func (m *Message) SetStringContent(content string) {
 	jsonContent, _ := json.Marshal(content)
 	jsonContent, _ := json.Marshal(content)
 	m.Content = jsonContent
 	m.Content = jsonContent

+ 5 - 3
relay/channel/openai/helper.go

@@ -31,6 +31,9 @@ func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo
 		return err
 		return err
 	}
 	}
 
 
+	if streamResponse.Usage != nil {
+		info.ClaudeConvertInfo.Usage = streamResponse.Usage
+	}
 	claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
 	claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
 	for _, resp := range claudeResponses {
 	for _, resp := range claudeResponses {
 		helper.ClaudeData(c, *resp)
 		helper.ClaudeData(c, *resp)
@@ -170,15 +173,14 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
 		helper.Done(c)
 		helper.Done(c)
 
 
 	case relaycommon.RelayFormatClaude:
 	case relaycommon.RelayFormatClaude:
+		info.ClaudeConvertInfo.Done = true
 		var streamResponse dto.ChatCompletionsStreamResponse
 		var streamResponse dto.ChatCompletionsStreamResponse
 		if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
 		if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			common.SysError("error unmarshalling stream response: " + err.Error())
 			return
 			return
 		}
 		}
 
 
-		if !containStreamUsage {
-			streamResponse.Usage = usage
-		}
+		info.ClaudeConvertInfo.Usage = usage
 
 
 		claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
 		claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
 		for _, resp := range claudeResponses {
 		for _, resp := range claudeResponses {

+ 2 - 0
relay/channel/openai/relay-openai.go

@@ -170,8 +170,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 			}
 			}
 		}
 		}
 	}
 	}
+
 	if shouldSendLastResp {
 	if shouldSendLastResp {
 		sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
 		sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
+		//err = handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
 	}
 	}
 
 
 	// 处理token计算
 	// 处理token计算

+ 8 - 3
relay/common/relay_info.go

@@ -19,13 +19,18 @@ type ThinkingContentInfo struct {
 }
 }
 
 
 const (
 const (
-	LastMessageTypeText  = "text"
-	LastMessageTypeTools = "tools"
+	LastMessageTypeNone     = "none"
+	LastMessageTypeText     = "text"
+	LastMessageTypeTools    = "tools"
+	LastMessageTypeThinking = "thinking"
 )
 )
 
 
 type ClaudeConvertInfo struct {
 type ClaudeConvertInfo struct {
 	LastMessagesType string
 	LastMessagesType string
 	Index            int
 	Index            int
+	Usage            *dto.Usage
+	FinishReason     string
+	Done             bool
 }
 }
 
 
 const (
 const (
@@ -113,7 +118,7 @@ func GenRelayInfoClaude(c *gin.Context) *RelayInfo {
 	info.RelayFormat = RelayFormatClaude
 	info.RelayFormat = RelayFormatClaude
 	info.ShouldIncludeUsage = false
 	info.ShouldIncludeUsage = false
 	info.ClaudeConvertInfo = ClaudeConvertInfo{
 	info.ClaudeConvertInfo = ClaudeConvertInfo{
-		LastMessagesType: LastMessageTypeText,
+		LastMessagesType: LastMessageTypeNone,
 	}
 	}
 	return info
 	return info
 }
 }

+ 73 - 30
service/convert.go

@@ -45,7 +45,7 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
 
 
 	// Add system message if present
 	// Add system message if present
 	if claudeRequest.System != nil {
 	if claudeRequest.System != nil {
-		if claudeRequest.IsStringSystem() {
+		if claudeRequest.IsStringSystem() && claudeRequest.GetStringSystem() != "" {
 			openAIMessage := dto.Message{
 			openAIMessage := dto.Message{
 				Role: "system",
 				Role: "system",
 			}
 			}
@@ -122,23 +122,22 @@ func ClaudeToOpenAIRequest(claudeRequest dto.ClaudeRequest) (*dto.GeneralOpenAIR
 						oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
 						oaiToolMessage.SetStringContent(mediaMsg.GetStringContent())
 					} else {
 					} else {
 						mediaContents := mediaMsg.ParseMediaContent()
 						mediaContents := mediaMsg.ParseMediaContent()
-						if len(mediaContents) > 0 && mediaContents[0].Text != nil {
-							oaiToolMessage.SetStringContent(*mediaContents[0].Text)
-						}
+						encodeJson, _ := common.EncodeJson(mediaContents)
+						oaiToolMessage.SetStringContent(string(encodeJson))
 					}
 					}
 					openAIMessages = append(openAIMessages, oaiToolMessage)
 					openAIMessages = append(openAIMessages, oaiToolMessage)
 				}
 				}
 			}
 			}
 
 
-			if len(mediaMessages) > 0 {
-				openAIMessage.SetMediaContent(mediaMessages)
-			}
-
 			if len(toolCalls) > 0 {
 			if len(toolCalls) > 0 {
 				openAIMessage.SetToolCalls(toolCalls)
 				openAIMessage.SetToolCalls(toolCalls)
 			}
 			}
+
+			if len(mediaMessages) > 0 && len(toolCalls) == 0 {
+				openAIMessage.SetMediaContent(mediaMessages)
+			}
 		}
 		}
-		if len(openAIMessage.ParseContent()) > 0 {
+		if len(openAIMessage.ParseContent()) > 0 || len(openAIMessage.ToolCalls) > 0 {
 			openAIMessages = append(openAIMessages, openAIMessage)
 			openAIMessages = append(openAIMessages, openAIMessage)
 		}
 		}
 	}
 	}
@@ -211,15 +210,15 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 			resp.SetIndex(0)
 			resp.SetIndex(0)
 			claudeResponses = append(claudeResponses, resp)
 			claudeResponses = append(claudeResponses, resp)
 		} else {
 		} else {
-			resp := &dto.ClaudeResponse{
-				Type: "content_block_start",
-				ContentBlock: &dto.ClaudeMediaMessage{
-					Type: "text",
-					Text: common.GetPointer[string](""),
-				},
-			}
-			resp.SetIndex(0)
-			claudeResponses = append(claudeResponses, resp)
+			//resp := &dto.ClaudeResponse{
+			//	Type: "content_block_start",
+			//	ContentBlock: &dto.ClaudeMediaMessage{
+			//		Type: "text",
+			//		Text: common.GetPointer[string](""),
+			//	},
+			//}
+			//resp.SetIndex(0)
+			//claudeResponses = append(claudeResponses, resp)
 		}
 		}
 		return claudeResponses
 		return claudeResponses
 	}
 	}
@@ -232,16 +231,20 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 		chosenChoice := openAIResponse.Choices[0]
 		chosenChoice := openAIResponse.Choices[0]
 		if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
 		if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
 			// should be done
 			// should be done
+			info.FinishReason = *chosenChoice.FinishReason
+			return claudeResponses
+		}
+		if info.Done {
 			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
 			claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
-			if openAIResponse.Usage != nil {
+			if info.ClaudeConvertInfo.Usage != nil {
 				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
 				claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
 					Type: "message_delta",
 					Type: "message_delta",
 					Usage: &dto.ClaudeUsage{
 					Usage: &dto.ClaudeUsage{
-						InputTokens:  openAIResponse.Usage.PromptTokens,
-						OutputTokens: openAIResponse.Usage.CompletionTokens,
+						InputTokens:  info.ClaudeConvertInfo.Usage.PromptTokens,
+						OutputTokens: info.ClaudeConvertInfo.Usage.CompletionTokens,
 					},
 					},
 					Delta: &dto.ClaudeMediaMessage{
 					Delta: &dto.ClaudeMediaMessage{
-						StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(*chosenChoice.FinishReason)),
+						StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
 					},
 					},
 				})
 				})
 			}
 			}
@@ -250,10 +253,10 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 			})
 			})
 		} else {
 		} else {
 			var claudeResponse dto.ClaudeResponse
 			var claudeResponse dto.ClaudeResponse
-			claudeResponse.SetIndex(0)
+			var isEmpty bool
 			claudeResponse.Type = "content_block_delta"
 			claudeResponse.Type = "content_block_delta"
 			if len(chosenChoice.Delta.ToolCalls) > 0 {
 			if len(chosenChoice.Delta.ToolCalls) > 0 {
-				if info.ClaudeConvertInfo.LastMessagesType == relaycommon.LastMessageTypeText {
+				if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeTools {
 					claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
 					claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
 					info.ClaudeConvertInfo.Index++
 					info.ClaudeConvertInfo.Index++
 					claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
 					claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
@@ -274,15 +277,55 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 					PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
 					PartialJson: &chosenChoice.Delta.ToolCalls[0].Function.Arguments,
 				}
 				}
 			} else {
 			} else {
-				info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
-				// text delta
-				claudeResponse.Delta = &dto.ClaudeMediaMessage{
-					Type: "text_delta",
-					Text: common.GetPointer[string](chosenChoice.Delta.GetContentString()),
+				reasoning := chosenChoice.Delta.GetReasoningContent()
+				textContent := chosenChoice.Delta.GetContentString()
+				if reasoning != "" || textContent != "" {
+					if reasoning != "" {
+						if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeThinking {
+							//info.ClaudeConvertInfo.Index++
+							claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+								Index: &info.ClaudeConvertInfo.Index,
+								Type:  "content_block_start",
+								ContentBlock: &dto.ClaudeMediaMessage{
+									Type:     "thinking",
+									Thinking: "",
+								},
+							})
+						}
+						info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeThinking
+						// text delta
+						claudeResponse.Delta = &dto.ClaudeMediaMessage{
+							Type:     "thinking_delta",
+							Thinking: reasoning,
+						}
+					} else {
+						if info.ClaudeConvertInfo.LastMessagesType != relaycommon.LastMessageTypeText {
+							claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
+							info.ClaudeConvertInfo.Index++
+							claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
+								Index: &info.ClaudeConvertInfo.Index,
+								Type:  "content_block_start",
+								ContentBlock: &dto.ClaudeMediaMessage{
+									Type: "text",
+									Text: common.GetPointer[string](""),
+								},
+							})
+						}
+						info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
+						// text delta
+						claudeResponse.Delta = &dto.ClaudeMediaMessage{
+							Type: "text_delta",
+							Text: common.GetPointer[string](textContent),
+						}
+					}
+				} else {
+					isEmpty = true
 				}
 				}
 			}
 			}
 			claudeResponse.Index = &info.ClaudeConvertInfo.Index
 			claudeResponse.Index = &info.ClaudeConvertInfo.Index
-			claudeResponses = append(claudeResponses, &claudeResponse)
+			if !isEmpty {
+				claudeResponses = append(claudeResponses, &claudeResponse)
+			}
 		}
 		}
 	}
 	}