Просмотр исходного кода

feat: Add Jina reranking support for OpenAI adaptor

[email protected] 10 месяцев назад
Родитель
Сommit
287caf8e38

+ 1 - 1
relay/channel/jina/adaptor.go

@@ -61,7 +61,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
-		err, usage = jinaRerankHandler(c, resp)
+		err, usage = JinaRerankHandler(c, resp)
 	} else if info.RelayMode == constant.RelayModeEmbeddings {
 		err, usage = jinaEmbeddingHandler(c, resp)
 	}

+ 1 - 1
relay/channel/jina/relay-jina.go

@@ -9,7 +9,7 @@ import (
 	"one-api/service"
 )
 
-func jinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func JinaRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil

+ 4 - 1
relay/channel/openai/adaptor.go

@@ -14,6 +14,7 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	"one-api/relay/channel/ai360"
+	"one-api/relay/channel/jina"
 	"one-api/relay/channel/lingyiwanwu"
 	"one-api/relay/channel/minimax"
 	"one-api/relay/channel/moonshot"
@@ -146,7 +147,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -228,6 +229,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
 	case constant.RelayModeImagesGenerations:
 		err, usage = OpenaiTTSHandler(c, resp, info)
+	case constant.RelayModeRerank:
+		err, usage = jina.JinaRerankHandler(c, resp)
 	default:
 		if info.IsStream {
 			err, usage = OaiStreamHandler(c, resp, info)