Просмотр исходного кода

🐛 fix: refactor response body handling in multiple relay handlers to utilize IOCopyBytesGracefully

CaIon 6 месяцев назад
Родитель
Сommit
1f4cf07b63

+ 18 - 17
common/http.go

@@ -3,9 +3,10 @@ package common
 import (
 	"bytes"
 	"fmt"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+
+	"github.com/gin-gonic/gin"
 )
 
 func CloseResponseBodyGracefully(httpResponse *http.Response) {
@@ -19,37 +20,37 @@ func CloseResponseBodyGracefully(httpResponse *http.Response) {
 }
 
 func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
-	if src == nil || src.Body == nil {
-		return
-	}
-
-	defer CloseResponseBodyGracefully(src)
-
 	if c.Writer == nil {
 		return
 	}
 
-	src.Body = io.NopCloser(bytes.NewBuffer(data))
+	body := io.NopCloser(bytes.NewBuffer(data))
 
 	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 	// And then we will have to send an error response, but in this case, the header has already been set.
 	// So the httpClient will be confused by the response.
 	// For example, Postman will report error, and we cannot check the response at all.
-	for k, v := range src.Header {
-		// avoid setting Content-Length
-		if k == "Content-Length" {
-			continue
+	if src != nil {
+		for k, v := range src.Header {
+			// avoid setting Content-Length
+			if k == "Content-Length" {
+				continue
+			}
+			c.Writer.Header().Set(k, v[0])
 		}
-		c.Writer.Header().Set(k, v[0])
 	}
 
-	// set Content-Length header manually
+	// set Content-Length header manually BEFORE calling WriteHeader
 	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
 
-	c.Writer.WriteHeader(src.StatusCode)
-	c.Writer.WriteHeaderNow()
+	// Write header with status code (this sends the headers)
+	if src != nil {
+		c.Writer.WriteHeader(src.StatusCode)
+	} else {
+		c.Writer.WriteHeader(http.StatusOK)
+	}
 
-	_, err := io.Copy(c.Writer, src.Body)
+	_, err := io.Copy(c.Writer, body)
 	if err != nil {
 		LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
 	}

+ 1 - 23
relay/channel/ollama/relay-ollama.go

@@ -1,7 +1,6 @@
 package ollama
 
 import (
-	"bytes"
 	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
@@ -118,28 +117,7 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 	}
-	resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the httpClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
-	// Copy headers
-	for k, v := range resp.Header {
-		// 删除任何现有的相同头部,以防止重复添加头部
-		c.Writer.Header().Del(k)
-		for _, vv := range v {
-			c.Writer.Header().Add(k, vv)
-		}
-	}
-	// reset content length
-	c.Writer.Header().Del("Content-Length")
-	c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(doResponseBody)))
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
-	}
-	common.CloseResponseBodyGracefully(resp)
+	common.IOCopyBytesGracefully(c, resp, doResponseBody)
 	return nil, usage
 }
 

+ 6 - 6
relay/channel/openai/relay-openai.go

@@ -181,12 +181,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 }
 
 func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	var simpleResponse dto.OpenAITextResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	common.CloseResponseBodyGracefully(resp)
 	err = common.DecodeJson(responseBody, &simpleResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -264,6 +265,8 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 }
 
 func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	// count tokens by audio file duration
 	audioTokens, err := countAudioTokens(c)
 	if err != nil {
@@ -273,8 +276,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	common.CloseResponseBodyGracefully(resp)
-
 	// 写入新的 response body
 	common.IOCopyBytesGracefully(c, resp, responseBody)
 
@@ -553,6 +554,8 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R
 }
 
 func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -564,9 +567,6 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 		return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
 	}
 
-	// 关闭旧的 response body(已被读取,再次读取会导致错误)
-	common.CloseResponseBodyGracefully(resp)
-
 	// 写入新的 response body
 	common.IOCopyBytesGracefully(c, resp, responseBody)
 

+ 5 - 18
relay/channel/openai/relay_responses.go

@@ -1,7 +1,6 @@
 package openai
 
 import (
-	"bytes"
 	"fmt"
 	"io"
 	"net/http"
@@ -16,13 +15,14 @@ import (
 )
 
 func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	// read response body
 	var responsesResponse dto.OpenAIResponsesResponse
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	common.CloseResponseBodyGracefully(resp)
 	err = common.DecodeJson(responseBody, &responsesResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
@@ -38,22 +38,9 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 		}, nil
 	}
 
-	// reset response body
-	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the httpClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
-	// copy response body
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		common.SysError("error copying response body: " + err.Error())
-	}
-	resp.Body.Close()
+	// 写入新的 response body
+	common.IOCopyBytesGracefully(c, resp, responseBody)
+
 	// compute usage
 	usage := dto.Usage{}
 	usage.PromptTokens = responsesResponse.Usage.InputTokens

+ 6 - 15
relay/channel/xai/text.go

@@ -1,9 +1,7 @@
 package xai
 
 import (
-	"bytes"
 	"encoding/json"
-	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -13,6 +11,8 @@ import (
 	"one-api/relay/helper"
 	"one-api/service"
 	"strings"
+
+	"github.com/gin-gonic/gin"
 )
 
 func streamResponseXAI2OpenAI(xAIResp *dto.ChatCompletionsStreamResponse, usage *dto.Usage) *dto.ChatCompletionsStreamResponse {
@@ -78,8 +78,10 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
 }
 
 func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	responseBody, err := io.ReadAll(resp.Body)
-	var response *dto.TextResponse
+	var response *dto.SimpleResponse
 	err = common.DecodeJson(responseBody, &response)
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
@@ -95,18 +97,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
 		return nil, nil
 	}
 
-	// set new body
-	resp.Body = io.NopCloser(bytes.NewBuffer(encodeJson))
-
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
-	}
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
-	}
-	common.CloseResponseBodyGracefully(resp)
+	common.IOCopyBytesGracefully(c, resp, encodeJson)
 
 	return nil, &response.Usage
 }

+ 1 - 4
relay/relay-mj.go

@@ -279,10 +279,7 @@ func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
 	if err != nil {
 		return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
 	}
-	_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
-	if err != nil {
-		return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
-	}
+	common.IOCopyBytesGracefully(c, nil, respBody)
 	return nil
 }