Browse Source

🐛 fix: refactor JSON unmarshalling across multiple handlers to use UnmarshalJson and UnmarshalJsonStr for consistency

This update replaces instances of DecodeJson and DecodeJsonStr with UnmarshalJson and UnmarshalJsonStr in various relay handlers, enhancing code consistency and clarity in JSON processing. The changes improve maintainability and align with recent refactoring efforts in the codebase.
CaIon 6 months ago
parent
commit
6b9237f868

+ 1 - 2
common/gin.go

@@ -2,7 +2,6 @@ package common
 
 import (
 	"bytes"
-	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"io"
 	"strings"
@@ -31,7 +30,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	}
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
-		err = json.Unmarshal(requestBody, &v)
+		err = UnmarshalJson(requestBody, &v)
 	} else {
 		// skip for now
 		// TODO: someday non json request have variant model, we will need to implementation this

+ 8 - 4
common/json.go

@@ -5,12 +5,16 @@ import (
 	"encoding/json"
 )
 
-func DecodeJson(data []byte, v any) error {
-	return json.NewDecoder(bytes.NewReader(data)).Decode(v)
+func UnmarshalJson(data []byte, v any) error {
+	return json.Unmarshal(data, v)
 }
 
-func DecodeJsonStr(data string, v any) error {
-	return DecodeJson(StringToByteSlice(data), v)
+func UnmarshalJsonStr(data string, v any) error {
+	return json.Unmarshal(StringToByteSlice(data), v)
+}
+
+func DecodeJson(reader *bytes.Reader, v any) error {
+	return json.NewDecoder(reader).Decode(v)
 }
 
 func EncodeJson(v any) ([]byte, error) {

+ 1 - 1
dto/openai_request.go

@@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct {
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
 	result := make(map[string]any)
 	data, _ := common.EncodeJson(r)
-	_ = common.DecodeJson(data, &result)
+	_ = common.UnmarshalJson(data, &result)
 	return result
 }
 

+ 7 - 7
relay/channel/claude/relay-claude.go

@@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 
 	if textRequest.Reasoning != nil {
 		var reasoning openrouter.RequestReasoning
-		if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil {
+		if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
 			return nil, err
 		}
 
@@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 
 func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
-	err := common.DecodeJsonStr(data, &claudeResponse)
+	err := common.UnmarshalJsonStr(data, &claudeResponse)
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
@@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 
 func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
-	err := common.DecodeJson(data, &claudeResponse)
+	err := common.UnmarshalJson(data, &claudeResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
 	}
@@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	case relaycommon.RelayFormatClaude:
 		responseData = data
 	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(http.StatusOK)
-	_, err = c.Writer.Write(responseData)
+
+	common.IOCopyBytesGracefully(c, nil, responseData)
 	return nil
 }
 
 func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	claudeInfo := &ClaudeResponseInfo{
 		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Created:      common.GetTimestamp(),
@@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	resp.Body.Close()
 	if common.DebugEnabled {
 		println("responseBody: ", string(responseBody))
 	}

+ 6 - 12
relay/channel/gemini/relay-gemini-native.go

@@ -1,7 +1,6 @@
 package gemini
 
 import (
-	"encoding/json"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -15,12 +14,13 @@ import (
 )
 
 func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	// 读取响应体
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
-	common.CloseResponseBodyGracefully(resp)
 
 	if common.DebugEnabled {
 		println(string(responseBody))
@@ -28,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 
 	// 解析为 Gemini 原生响应格式
 	var geminiResponse GeminiChatResponse
-	err = common.DecodeJson(responseBody, &geminiResponse)
+	err = common.UnmarshalJson(responseBody, &geminiResponse)
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
@@ -51,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	// 直接返回 Gemini 原生格式的 JSON 响应
-	jsonResponse, err := json.Marshal(geminiResponse)
+	jsonResponse, err := common.EncodeJson(geminiResponse)
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}
 
-	// 设置响应头并写入响应
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	if err != nil {
-		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
-	}
+	common.IOCopyBytesGracefully(c, resp, jsonResponse)
 
 	return &usage, nil
 }
@@ -77,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
-		err := common.DecodeJsonStr(data, &geminiResponse)
+		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false

+ 6 - 8
relay/channel/gemini/relay-gemini.go

@@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
-		err := common.DecodeJsonStr(data, &geminiResponse)
+		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
@@ -871,7 +871,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 		println(string(responseBody))
 	}
 	var geminiResponse GeminiChatResponse
-	err = common.DecodeJson(responseBody, &geminiResponse)
+	err = common.UnmarshalJson(responseBody, &geminiResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
@@ -917,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 }
 
 func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	responseBody, readErr := io.ReadAll(resp.Body)
 	if readErr != nil {
 		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
 	}
-	_ = resp.Body.Close()
 
 	var geminiResponse GeminiEmbeddingResponse
 	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
@@ -953,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
 	}
 	openAIResponse.Usage = *usage.(*dto.Usage)
 
-	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
 	if jsonErr != nil {
 		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
 	}
 
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, _ = c.Writer.Write(jsonResponse)
-
+	common.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return usage, nil
 }

+ 5 - 5
relay/channel/openai/relay-openai.go

@@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
 	}
 
 	var lastStreamResponse dto.ChatCompletionsStreamResponse
-	if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
+	if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
 		return err
 	}
 
@@ -188,7 +188,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = common.DecodeJson(responseBody, &simpleResponse)
+	err = common.UnmarshalJson(responseBody, &simpleResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
@@ -368,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				}
 
 				realtimeEvent := &dto.RealtimeEvent{}
-				err = common.DecodeJson(message, realtimeEvent)
+				err = common.UnmarshalJson(message, realtimeEvent)
 				if err != nil {
 					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 					return
@@ -428,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				}
 				info.SetFirstResponseTime()
 				realtimeEvent := &dto.RealtimeEvent{}
-				err = common.DecodeJson(message, realtimeEvent)
+				err = common.UnmarshalJson(message, realtimeEvent)
 				if err != nil {
 					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 					return
@@ -562,7 +562,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 	}
 
 	var usageResp dto.SimpleResponse
-	err = common.DecodeJson(responseBody, &usageResp)
+	err = common.UnmarshalJson(responseBody, &usageResp)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
 	}

+ 2 - 2
relay/channel/openai/relay_responses.go

@@ -23,7 +23,7 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = common.DecodeJson(responseBody, &responsesResponse)
+	err = common.UnmarshalJson(responseBody, &responsesResponse)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
@@ -66,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 
 		// 检查当前数据是否包含 completed 状态和 usage 信息
 		var streamResponse dto.ResponsesStreamResponse
-		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+		if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
 			sendResponsesStreamData(c, streamResponse, data)
 			switch streamResponse.Type {
 			case "response.completed":

+ 1 - 1
relay/channel/xai/text.go

@@ -82,7 +82,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
 
 	responseBody, err := io.ReadAll(resp.Body)
 	var response *dto.SimpleResponse
-	err = common.DecodeJson(responseBody, &response)
+	err = common.UnmarshalJson(responseBody, &response)
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		return nil, nil

+ 2 - 2
relay/common_handler/rerank.go

@@ -23,7 +23,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	var jinaResp dto.RerankResponse
 	if info.ChannelType == common.ChannelTypeXinference {
 		var xinRerankResponse xinference.XinRerankResponse
-		err = common.DecodeJson(responseBody, &xinRerankResponse)
+		err = common.UnmarshalJson(responseBody, &xinRerankResponse)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		}
@@ -58,7 +58,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 			},
 		}
 	} else {
-		err = common.DecodeJson(responseBody, &jinaResp)
+		err = common.UnmarshalJson(responseBody, &jinaResp)
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		}