Browse Source

✨ feat(adaptor): refactor response handlers to return (usage, err) instead of (err, usage) and improve error handling

CaIon 6 months ago
parent
commit
52a5e58f0c

+ 3 - 3
relay/channel/cohere/adaptor.go

@@ -74,12 +74,12 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeRerank {
-		err, usage = cohereRerankHandler(c, resp, info)
+		usage, err = cohereRerankHandler(c, resp, info)
 	} else {
 		if info.IsStream {
-			err, usage = cohereStreamHandler(c, info, resp)
+			usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
 		} else {
-			err, usage = cohereHandler(c, info, resp)
+			usage, err = cohereHandler(c, info, resp)
 		}
 	}
 	return

+ 13 - 13
relay/channel/cohere/relay-cohere.go

@@ -78,7 +78,7 @@ func stopReasonCohere2OpenAI(reason string) string {
 	}
 }
 
-func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	responseId := helper.GetResponseID(c)
 	createdTime := common.GetTimestamp()
 	usage := &dto.Usage{}
@@ -166,20 +166,20 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
 	if usage.PromptTokens == 0 {
 		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 	}
-	return nil, usage
+	return usage, nil
 }
 
-func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	createdTime := common.GetTimestamp()
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	common.CloseResponseBodyGracefully(resp)
 	var cohereResp CohereResponseResult
 	err = json.Unmarshal(responseBody, &cohereResp)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	usage := dto.Usage{}
 	usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
@@ -203,24 +203,24 @@ func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
 
 	jsonResponse, err := json.Marshal(openaiResp)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = c.Writer.Write(jsonResponse)
-	return nil, &usage
+	_, _ = c.Writer.Write(jsonResponse)
+	return &usage, nil
 }
 
-func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	common.CloseResponseBodyGracefully(resp)
 	var cohereResp CohereRerankResponseResult
 	err = json.Unmarshal(responseBody, &cohereResp)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	usage := dto.Usage{}
 	if cohereResp.Meta.BilledUnits.InputTokens == 0 {
@@ -239,10 +239,10 @@ func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.
 
 	jsonResponse, err := json.Marshal(rerankResp)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = c.Writer.Write(jsonResponse)
-	return nil, &usage
+	return &usage, nil
 }

+ 2 - 2
relay/channel/coze/adaptor.go

@@ -98,9 +98,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody
 // DoResponse implements channel.Adaptor.
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.IsStream {
-		err, usage = cozeChatStreamHandler(c, info, resp)
+		usage, err = cozeChatStreamHandler(c, info, resp)
 	} else {
-		err, usage = cozeChatHandler(c, info, resp)
+		usage, err = cozeChatHandler(c, info, resp)
 	}
 	return
 }

+ 9 - 9
relay/channel/coze/relay-coze.go

@@ -44,10 +44,10 @@ func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *C
 	return cozeRequest
 }
 
-func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	common.CloseResponseBodyGracefully(resp)
 	// convert coze response to openai response
@@ -56,10 +56,10 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 	response.Model = info.UpstreamModelName
 	err = json.Unmarshal(responseBody, &cozeResponse)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	if cozeResponse.Code != 0 {
-		return types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
 	}
 	// 从上下文获取 usage
 	var usage dto.Usage
@@ -86,16 +86,16 @@ func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Res
 	}
 	jsonResponse, err := json.Marshal(response)
 	if err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	c.Writer.Header().Set("Content-Type", "application/json")
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, _ = c.Writer.Write(jsonResponse)
 
-	return nil, &usage
+	return &usage, nil
 }
 
-func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
+func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 	helper.SetEventStreamHeaders(c)
@@ -136,7 +136,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
 	}
 
 	if err := scanner.Err(); err != nil {
-		return types.NewError(err, types.ErrorCodeBadResponseBody), nil
+		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	helper.Done(c)
 
@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *ht
 		usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
 	}
 
-	return nil, usage
+	return usage, nil
 }
 
 func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {

+ 1 - 1
relay/channel/ollama/relay-ollama.go

@@ -96,7 +96,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
 		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
 	}
 	if ollamaEmbeddingResponse.Error != "" {
-		return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
+		return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
 	}
 	flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
 	data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)