Browse Source

refactor: Enhance error handling in AWS and Claude response processing by updating function signatures and improving error propagation

[email protected] 9 months ago
parent
commit
ee302c063c
2 changed files with 18 additions and 18 deletions
  1. 5 6
      relay/channel/aws/relay-aws.go
  2. 13 12
      relay/channel/claude/relay-claude.go

+ 5 - 6
relay/channel/aws/relay-aws.go

@@ -10,7 +10,6 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 	"strings"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
@@ -151,7 +150,10 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 		switch v := event.(type) {
 		case *types.ResponseStreamMemberChunk:
 			info.SetFirstResponseTime()
-			claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
+			err = claude.HandleResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
+			if err != nil {
+				return wrapErr(err), nil
+			}
 		case *types.UnknownUnionMember:
 			fmt.Println("unknown tag:", v.Tag)
 			return wrapErr(errors.New("unknown response type")), nil
@@ -164,10 +166,7 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	claude.HandleFinalResponse(c, info, claudeInfo, RequestModeMessage)
 
 	if resp != nil {
-		err = resp.Body.Close()
-		if err != nil {
-			return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
-		}
+		resp.Body.Close()
 	}
 	return nil, claudeInfo.Usage
 }

+ 13 - 12
relay/channel/claude/relay-claude.go

@@ -479,12 +479,12 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 	return true
 }
 
-func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) bool {
+func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) error {
 	var claudeResponse dto.ClaudeResponse
 	err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
-		return false
+		return fmt.Errorf("error unmarshalling stream aws response: %w", err)
 	}
 	if info.RelayFormat == relaycommon.RelayFormatClaude {
 		if requestMode == RequestModeCompletion {
@@ -510,16 +510,10 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo
 		}
 		helper.ClaudeChunkData(c, claudeResponse, data)
 	} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
-		err := json.NewDecoder(bytes.NewReader(common.StringToByteSlice(data))).Decode(&claudeResponse)
-		if err != nil {
-			common.SysError("error unmarshalling stream response: " + err.Error())
-			return false
-		}
-
 		response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
 
 		if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
-			return true
+			return nil
 		}
 
 		err = helper.ObjectData(c, response)
@@ -527,7 +521,7 @@ func HandleResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo
 			common.LogError(c, "send_stream_response_failed: "+err.Error())
 		}
 	}
-	return true
+	return nil
 }
 
 func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
@@ -573,10 +567,17 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		ResponseText: strings.Builder{},
 		Usage:        &dto.Usage{},
 	}
-
+	var err error
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
-		return HandleResponseData(c, info, claudeInfo, data, requestMode)
+		err = HandleResponseData(c, info, claudeInfo, data, requestMode)
+		if err != nil {
+			return false
+		}
+		return true
 	})
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError), nil
+	}
 
 	HandleFinalResponse(c, info, claudeInfo, requestMode)