| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- package common_handler
- import (
- "io"
- "net/http"
- "github.com/QuantumNous/new-api/common"
- "github.com/QuantumNous/new-api/constant"
- "github.com/QuantumNous/new-api/dto"
- "github.com/QuantumNous/new-api/relay/channel/xinference"
- relaycommon "github.com/QuantumNous/new-api/relay/common"
- "github.com/QuantumNous/new-api/service"
- "github.com/QuantumNous/new-api/types"
- "github.com/gin-gonic/gin"
- )
- func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
- }
- service.CloseResponseBodyGracefully(resp)
- if common.DebugEnabled {
- println("reranker response body: ", string(responseBody))
- }
- var jinaResp dto.RerankResponse
- if info.ChannelType == constant.ChannelTypeXinference {
- var xinRerankResponse xinference.XinRerankResponse
- err = common.Unmarshal(responseBody, &xinRerankResponse)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
- for i, result := range xinRerankResponse.Results {
- respResult := dto.RerankResponseResult{
- Index: result.Index,
- RelevanceScore: result.RelevanceScore,
- }
- if info.ReturnDocuments {
- var document any
- if result.Document != nil {
- if doc, ok := result.Document.(string); ok {
- if doc == "" {
- document = info.Documents[result.Index]
- } else {
- document = doc
- }
- } else {
- document = result.Document
- }
- }
- respResult.Document = document
- }
- jinaRespResults[i] = respResult
- }
- jinaResp = dto.RerankResponse{
- Results: jinaRespResults,
- Usage: dto.Usage{
- PromptTokens: info.PromptTokens,
- TotalTokens: info.PromptTokens,
- },
- }
- } else {
- err = common.Unmarshal(responseBody, &jinaResp)
- if err != nil {
- return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
- }
- jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
- }
- c.Writer.Header().Set("Content-Type", "application/json")
- c.JSON(http.StatusOK, jinaResp)
- return &jinaResp.Usage, nil
- }
|