Ver código fonte

refactor: Streamline AWS and Claude response handling by consolidating logic and improving error management

[email protected] 9 meses atrás
pai
commit
2c81a5f0cc

+ 7 - 52
relay/channel/aws/relay-aws.go

@@ -1,21 +1,17 @@
 package aws
 
 import (
-	"bytes"
 	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
-	"io"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
-	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
-	"time"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/credentials"
@@ -143,7 +139,6 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	stream := awsResp.GetStream()
 	defer stream.Close()
 
-	c.Writer.Header().Set("Content-Type", "text/event-stream")
 	claudeInfo := &claude.ClaudeResponseInfo{
 		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Created:      common.GetTimestamp(),
@@ -151,63 +146,23 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		ResponseText: strings.Builder{},
 		Usage:        &dto.Usage{},
 	}
-	isFirst := true
-	c.Stream(func(w io.Writer) bool {
-		event, ok := <-stream.Events()
-		if !ok {
-			return false
-		}
 
+	for event := range stream.Events() {
 		switch v := event.(type) {
 		case *types.ResponseStreamMemberChunk:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			claudeResponse := new(dto.ClaudeResponse)
-			err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResponse)
-			if err != nil {
-				common.SysError("error unmarshalling stream response: " + err.Error())
-				return false
-			}
-
-			response := claude.StreamResponseClaude2OpenAI(requestMode, claudeResponse)
-
-			if !claude.FormatClaudeResponseInfo(RequestModeMessage, claudeResponse, response, claudeInfo) {
-				return true
-			}
-
-			jsonStr, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
-			return true
+			info.SetFirstResponseTime()
+			claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
 		case *types.UnknownUnionMember:
 			fmt.Println("unknown tag:", v.Tag)
-			return false
+			return wrapErr(errors.New("unknown response type")), nil
 		default:
 			fmt.Println("union is nil or unknown type")
-			return false
+			return wrapErr(errors.New("nil or unknown response type")), nil
 		}
-	})
-
-	if claudeInfo.Usage.PromptTokens == 0 {
-		//上游出错
-	}
-	if claudeInfo.Usage.CompletionTokens == 0 {
-		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
 	}
 
-	if info.ShouldIncludeUsage {
-		response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
-		err := helper.ObjectData(c, response)
-		if err != nil {
-			common.SysError("send final response failed: " + err.Error())
-		}
-	}
-	helper.Done(c)
+	claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage)
+
 	if resp != nil {
 		err = resp.Body.Close()
 		if err != nil {

+ 67 - 70
relay/channel/claude/relay-claude.go

@@ -479,77 +479,41 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 	return true
 }
 
-func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-
-	if info.RelayFormat == relaycommon.RelayFormatOpenAI {
-		return toOpenAIStreamHandler(c, resp, info, requestMode)
+func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) bool {
+	var claudeResponse dto.ClaudeResponse
+	err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
+	if err != nil {
+		common.SysError("error unmarshalling stream response: " + err.Error())
+		return false
 	}
-
-	usage := &dto.Usage{}
-	responseText := strings.Builder{}
-
-	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		var claudeResponse dto.ClaudeResponse
-		err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
-		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
-			return true
-		}
+	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		if requestMode == RequestModeCompletion {
-			responseText.WriteString(claudeResponse.Completion)
+			claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
 		} else {
 			if claudeResponse.Type == "message_start" {
 				// message_start, 获取usage
 				info.UpstreamModelName = claudeResponse.Message.Model
-				usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
-				usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
-				usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
-				usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
+				claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
+				claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
+				claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
+				claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
 			} else if claudeResponse.Type == "content_block_delta" {
-				responseText.WriteString(claudeResponse.Delta.GetText())
+				claudeInfo.ResponseText.WriteString(claudeResponse.Delta.GetText())
 			} else if claudeResponse.Type == "message_delta" {
 				if claudeResponse.Usage.InputTokens > 0 {
 					// 不叠加,只取最新的
-					usage.PromptTokens = claudeResponse.Usage.InputTokens
+					claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
 				}
-				usage.CompletionTokens = claudeResponse.Usage.OutputTokens
-				usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+				claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
+				claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
 			}
 		}
 		helper.ClaudeChunkData(c, claudeResponse, data)
-		return true
-	})
-
-	if requestMode == RequestModeCompletion {
-		usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
-	} else {
-		// 说明流模式建立失败,可能为官方出错
-		if usage.PromptTokens == 0 {
-			//usage.PromptTokens = info.PromptTokens
-		}
-		if usage.CompletionTokens == 0 {
-			usage, _ = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, usage.PromptTokens)
-		}
-	}
-	return nil, usage
-}
-
-func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
-	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
-	claudeInfo := &ClaudeResponseInfo{
-		ResponseId:   responseId,
-		Created:      common.GetTimestamp(),
-		Model:        info.UpstreamModelName,
-		ResponseText: strings.Builder{},
-		Usage:        &dto.Usage{},
-	}
-
-	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		var claudeResponse dto.ClaudeResponse
+	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
 		err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
 		if err != nil {
 			common.SysError("error unmarshalling stream response: " + err.Error())
-			return true
+			return false
 		}
 
 		response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
@@ -562,27 +526,60 @@ func toOpenAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
 		if err != nil {
 			common.LogError(c, "send_stream_response_failed: "+err.Error())
 		}
-		return true
-	})
+	}
+	return true
+}
 
-	if requestMode == RequestModeCompletion {
-		claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
-	} else {
-		if claudeInfo.Usage.PromptTokens == 0 {
-			//上游出错
+func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
+	if info.RelayFormat == relaycommon.RelayFormatClaude {
+		if requestMode == RequestModeCompletion {
+			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		} else {
+			// 说明流模式建立失败,可能为官方出错
+			if claudeInfo.Usage.PromptTokens == 0 {
+				//usage.PromptTokens = info.PromptTokens
+			}
+			if claudeInfo.Usage.CompletionTokens == 0 {
+				claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+			}
 		}
-		if claudeInfo.Usage.CompletionTokens == 0 {
-			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
+		if requestMode == RequestModeCompletion {
+			claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+		} else {
+			if claudeInfo.Usage.PromptTokens == 0 {
+				//上游出错
+			}
+			if claudeInfo.Usage.CompletionTokens == 0 {
+				claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+			}
 		}
-	}
-	if info.ShouldIncludeUsage {
-		response := helper.GenerateFinalUsageResponse(responseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
-		err := helper.ObjectData(c, response)
-		if err != nil {
-			common.SysError("send final response failed: " + err.Error())
+		if info.ShouldIncludeUsage {
+			response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
+			err := helper.ObjectData(c, response)
+			if err != nil {
+				common.SysError("send final response failed: " + err.Error())
+			}
 		}
+		helper.Done(c)
 	}
-	helper.Done(c)
+}
+
+func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	claudeInfo := &ClaudeResponseInfo{
+		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Created:      common.GetTimestamp(),
+		Model:        info.UpstreamModelName,
+		ResponseText: strings.Builder{},
+		Usage:        &dto.Usage{},
+	}
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		return HandleResponseData(c, info, claudeInfo, data, requestMode)
+	})
+
+	HandleFinalResponse(c, info, claudeInfo, requestMode)
+
 	return nil, claudeInfo.Usage
 }
 

+ 1 - 0
relay/channel/xinference/constant.go

@@ -2,6 +2,7 @@ package xinference
 
 var ModelList = []string{
 	"bge-reranker-v2-m3",
+	"jina-reranker-v2",
 }
 
 var ChannelName = "xinference"