Browse Source

🐛 fix: refactor JSON unmarshalling across multiple handlers to use UnmarshalJson and UnmarshalJsonStr for consistency

This update replaces instances of DecodeJson and DecodeJsonStr with UnmarshalJson and UnmarshalJsonStr in various relay handlers, enhancing code consistency and clarity in JSON processing. The changes improve maintainability and align with recent refactoring efforts in the codebase.
CaIon 6 months ago
parent
commit
6b9237f868

+ 1 - 2
common/gin.go

@@ -2,7 +2,6 @@ package common
 
 
 import (
 import (
 	"bytes"
 	"bytes"
-	"encoding/json"
 	"github.com/gin-gonic/gin"
 	"github.com/gin-gonic/gin"
 	"io"
 	"io"
 	"strings"
 	"strings"
@@ -31,7 +30,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
 	}
 	}
 	contentType := c.Request.Header.Get("Content-Type")
 	contentType := c.Request.Header.Get("Content-Type")
 	if strings.HasPrefix(contentType, "application/json") {
 	if strings.HasPrefix(contentType, "application/json") {
-		err = json.Unmarshal(requestBody, &v)
+		err = UnmarshalJson(requestBody, &v)
 	} else {
 	} else {
 		// skip for now
 		// skip for now
 		// TODO: someday non json request have variant model, we will need to implementation this
 		// TODO: someday non json request have variant model, we will need to implementation this

+ 8 - 4
common/json.go

@@ -5,12 +5,16 @@ import (
 	"encoding/json"
 	"encoding/json"
 )
 )
 
 
-func DecodeJson(data []byte, v any) error {
-	return json.NewDecoder(bytes.NewReader(data)).Decode(v)
+func UnmarshalJson(data []byte, v any) error {
+	return json.Unmarshal(data, v)
 }
 }
 
 
-func DecodeJsonStr(data string, v any) error {
-	return DecodeJson(StringToByteSlice(data), v)
+func UnmarshalJsonStr(data string, v any) error {
+	return json.Unmarshal(StringToByteSlice(data), v)
+}
+
+func DecodeJson(reader *bytes.Reader, v any) error {
+	return json.NewDecoder(reader).Decode(v)
 }
 }
 
 
 func EncodeJson(v any) ([]byte, error) {
 func EncodeJson(v any) ([]byte, error) {

+ 1 - 1
dto/openai_request.go

@@ -66,7 +66,7 @@ type GeneralOpenAIRequest struct {
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
 func (r *GeneralOpenAIRequest) ToMap() map[string]any {
 	result := make(map[string]any)
 	result := make(map[string]any)
 	data, _ := common.EncodeJson(r)
 	data, _ := common.EncodeJson(r)
-	_ = common.DecodeJson(data, &result)
+	_ = common.UnmarshalJson(data, &result)
 	return result
 	return result
 }
 }
 
 

+ 7 - 7
relay/channel/claude/relay-claude.go

@@ -125,7 +125,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
 
 
 	if textRequest.Reasoning != nil {
 	if textRequest.Reasoning != nil {
 		var reasoning openrouter.RequestReasoning
 		var reasoning openrouter.RequestReasoning
-		if err := common.DecodeJson(textRequest.Reasoning, &reasoning); err != nil {
+		if err := common.UnmarshalJson(textRequest.Reasoning, &reasoning); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 
 
@@ -519,7 +519,7 @@ func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeRespons
 
 
 func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
 func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
 	var claudeResponse dto.ClaudeResponse
-	err := common.DecodeJsonStr(data, &claudeResponse)
+	err := common.UnmarshalJsonStr(data, &claudeResponse)
 	if err != nil {
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "stream_response_error", http.StatusInternalServerError)
@@ -619,7 +619,7 @@ func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 
 
 func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
 func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *dto.OpenAIErrorWithStatusCode {
 	var claudeResponse dto.ClaudeResponse
 	var claudeResponse dto.ClaudeResponse
-	err := common.DecodeJson(data, &claudeResponse)
+	err := common.UnmarshalJson(data, &claudeResponse)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
 		return service.OpenAIErrorWrapper(err, "unmarshal_claude_response_failed", http.StatusInternalServerError)
 	}
 	}
@@ -657,13 +657,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	case relaycommon.RelayFormatClaude:
 	case relaycommon.RelayFormatClaude:
 		responseData = data
 		responseData = data
 	}
 	}
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(http.StatusOK)
-	_, err = c.Writer.Write(responseData)
+
+	common.IOCopyBytesGracefully(c, nil, responseData)
 	return nil
 	return nil
 }
 }
 
 
 func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	claudeInfo := &ClaudeResponseInfo{
 	claudeInfo := &ClaudeResponseInfo{
 		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		ResponseId:   fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Created:      common.GetTimestamp(),
 		Created:      common.GetTimestamp(),
@@ -675,7 +676,6 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
-	resp.Body.Close()
 	if common.DebugEnabled {
 	if common.DebugEnabled {
 		println("responseBody: ", string(responseBody))
 		println("responseBody: ", string(responseBody))
 	}
 	}

+ 6 - 12
relay/channel/gemini/relay-gemini-native.go

@@ -1,7 +1,6 @@
 package gemini
 package gemini
 
 
 import (
 import (
-	"encoding/json"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
 	"one-api/common"
 	"one-api/common"
@@ -15,12 +14,13 @@ import (
 )
 )
 
 
 func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
 func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	// 读取响应体
 	// 读取响应体
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	}
-	common.CloseResponseBodyGracefully(resp)
 
 
 	if common.DebugEnabled {
 	if common.DebugEnabled {
 		println(string(responseBody))
 		println(string(responseBody))
@@ -28,7 +28,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 
 
 	// 解析为 Gemini 原生响应格式
 	// 解析为 Gemini 原生响应格式
 	var geminiResponse GeminiChatResponse
 	var geminiResponse GeminiChatResponse
-	err = common.DecodeJson(responseBody, &geminiResponse)
+	err = common.UnmarshalJson(responseBody, &geminiResponse)
 	if err != nil {
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 	}
@@ -51,18 +51,12 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	}
 	}
 
 
 	// 直接返回 Gemini 原生格式的 JSON 响应
 	// 直接返回 Gemini 原生格式的 JSON 响应
-	jsonResponse, err := json.Marshal(geminiResponse)
+	jsonResponse, err := common.EncodeJson(geminiResponse)
 	if err != nil {
 	if err != nil {
 		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}
 	}
 
 
-	// 设置响应头并写入响应
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	if err != nil {
-		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
-	}
+	common.IOCopyBytesGracefully(c, resp, jsonResponse)
 
 
 	return &usage, nil
 	return &usage, nil
 }
 }
@@ -77,7 +71,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
 
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
 		var geminiResponse GeminiChatResponse
-		err := common.DecodeJsonStr(data, &geminiResponse)
+		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		if err != nil {
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
 			return false

+ 6 - 8
relay/channel/gemini/relay-gemini.go

@@ -801,7 +801,7 @@ func GeminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 
 
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
 		var geminiResponse GeminiChatResponse
 		var geminiResponse GeminiChatResponse
-		err := common.DecodeJsonStr(data, &geminiResponse)
+		err := common.UnmarshalJsonStr(data, &geminiResponse)
 		if err != nil {
 		if err != nil {
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			common.LogError(c, "error unmarshalling stream response: "+err.Error())
 			return false
 			return false
@@ -871,7 +871,7 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 		println(string(responseBody))
 		println(string(responseBody))
 	}
 	}
 	var geminiResponse GeminiChatResponse
 	var geminiResponse GeminiChatResponse
-	err = common.DecodeJson(responseBody, &geminiResponse)
+	err = common.UnmarshalJson(responseBody, &geminiResponse)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
@@ -917,11 +917,12 @@ func GeminiChatHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
 }
 }
 
 
 func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	defer common.CloseResponseBodyGracefully(resp)
+
 	responseBody, readErr := io.ReadAll(resp.Body)
 	responseBody, readErr := io.ReadAll(resp.Body)
 	if readErr != nil {
 	if readErr != nil {
 		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
 		return nil, service.OpenAIErrorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	}
-	_ = resp.Body.Close()
 
 
 	var geminiResponse GeminiEmbeddingResponse
 	var geminiResponse GeminiEmbeddingResponse
 	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
 	if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
@@ -953,14 +954,11 @@ func GeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycomm
 	}
 	}
 	openAIResponse.Usage = *usage.(*dto.Usage)
 	openAIResponse.Usage = *usage.(*dto.Usage)
 
 
-	jsonResponse, jsonErr := json.Marshal(openAIResponse)
+	jsonResponse, jsonErr := common.EncodeJson(openAIResponse)
 	if jsonErr != nil {
 	if jsonErr != nil {
 		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
 		return nil, service.OpenAIErrorWrapper(jsonErr, "marshal_response_failed", http.StatusInternalServerError)
 	}
 	}
 
 
-	c.Writer.Header().Set("Content-Type", "application/json")
-	c.Writer.WriteHeader(resp.StatusCode)
-	_, _ = c.Writer.Write(jsonResponse)
-
+	common.IOCopyBytesGracefully(c, resp, jsonResponse)
 	return usage, nil
 	return usage, nil
 }
 }

+ 5 - 5
relay/channel/openai/relay-openai.go

@@ -33,7 +33,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
 	}
 	}
 
 
 	var lastStreamResponse dto.ChatCompletionsStreamResponse
 	var lastStreamResponse dto.ChatCompletionsStreamResponse
-	if err := common.DecodeJsonStr(data, &lastStreamResponse); err != nil {
+	if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -188,7 +188,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
-	err = common.DecodeJson(responseBody, &simpleResponse)
+	err = common.UnmarshalJson(responseBody, &simpleResponse)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
@@ -368,7 +368,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				}
 				}
 
 
 				realtimeEvent := &dto.RealtimeEvent{}
 				realtimeEvent := &dto.RealtimeEvent{}
-				err = common.DecodeJson(message, realtimeEvent)
+				err = common.UnmarshalJson(message, realtimeEvent)
 				if err != nil {
 				if err != nil {
 					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 					return
 					return
@@ -428,7 +428,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op
 				}
 				}
 				info.SetFirstResponseTime()
 				info.SetFirstResponseTime()
 				realtimeEvent := &dto.RealtimeEvent{}
 				realtimeEvent := &dto.RealtimeEvent{}
-				err = common.DecodeJson(message, realtimeEvent)
+				err = common.UnmarshalJson(message, realtimeEvent)
 				if err != nil {
 				if err != nil {
 					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 					errChan <- fmt.Errorf("error unmarshalling message: %v", err)
 					return
 					return
@@ -562,7 +562,7 @@ func OpenaiHandlerWithUsage(c *gin.Context, resp *http.Response, info *relaycomm
 	}
 	}
 
 
 	var usageResp dto.SimpleResponse
 	var usageResp dto.SimpleResponse
-	err = common.DecodeJson(responseBody, &usageResp)
+	err = common.UnmarshalJson(responseBody, &usageResp)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "parse_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}

+ 2 - 2
relay/channel/openai/relay_responses.go

@@ -23,7 +23,7 @@ func OaiResponsesHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
-	err = common.DecodeJson(responseBody, &responsesResponse)
+	err = common.UnmarshalJson(responseBody, &responsesResponse)
 	if err != nil {
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}
 	}
@@ -66,7 +66,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
 
 
 		// 检查当前数据是否包含 completed 状态和 usage 信息
 		// 检查当前数据是否包含 completed 状态和 usage 信息
 		var streamResponse dto.ResponsesStreamResponse
 		var streamResponse dto.ResponsesStreamResponse
-		if err := common.DecodeJsonStr(data, &streamResponse); err == nil {
+		if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
 			sendResponsesStreamData(c, streamResponse, data)
 			sendResponsesStreamData(c, streamResponse, data)
 			switch streamResponse.Type {
 			switch streamResponse.Type {
 			case "response.completed":
 			case "response.completed":

+ 1 - 1
relay/channel/xai/text.go

@@ -82,7 +82,7 @@ func xAIHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo
 
 
 	responseBody, err := io.ReadAll(resp.Body)
 	responseBody, err := io.ReadAll(resp.Body)
 	var response *dto.SimpleResponse
 	var response *dto.SimpleResponse
-	err = common.DecodeJson(responseBody, &response)
+	err = common.UnmarshalJson(responseBody, &response)
 	if err != nil {
 	if err != nil {
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		common.SysError("error unmarshalling stream response: " + err.Error())
 		return nil, nil
 		return nil, nil

+ 2 - 2
relay/common_handler/rerank.go

@@ -23,7 +23,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 	var jinaResp dto.RerankResponse
 	var jinaResp dto.RerankResponse
 	if info.ChannelType == common.ChannelTypeXinference {
 	if info.ChannelType == common.ChannelTypeXinference {
 		var xinRerankResponse xinference.XinRerankResponse
 		var xinRerankResponse xinference.XinRerankResponse
-		err = common.DecodeJson(responseBody, &xinRerankResponse)
+		err = common.UnmarshalJson(responseBody, &xinRerankResponse)
 		if err != nil {
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		}
 		}
@@ -58,7 +58,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 			},
 			},
 		}
 		}
 	} else {
 	} else {
-		err = common.DecodeJson(responseBody, &jinaResp)
+		err = common.UnmarshalJson(responseBody, &jinaResp)
 		if err != nil {
 		if err != nil {
 			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 			return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 		}
 		}