Browse Source

feat: ali rerank

RedwindA 7 months ago
parent
commit
49e77fb3df

+ 5 - 1
relay/channel/ali/adaptor.go

@@ -31,6 +31,8 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	switch info.RelayMode {
 	case constant.RelayModeEmbeddings:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl)
+	case constant.RelayModeRerank:
+		fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
 	case constant.RelayModeImagesGenerations:
 		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
 	case constant.RelayModeCompletions:
@@ -76,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
-	return nil, errors.New("not implemented")
+	return ConvertRerankRequest(request), nil
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
@@ -103,6 +105,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		err, usage = aliImageHandler(c, resp, info)
 	case constant.RelayModeEmbeddings:
 		err, usage = aliEmbeddingHandler(c, resp)
+	case constant.RelayModeRerank:
+		err, usage = RerankHandler(c, resp, info)
 	default:
 		if info.IsStream {
 			err, usage = openai.OaiStreamHandler(c, resp, info)

+ 1 - 0
relay/channel/ali/constants.go

@@ -8,6 +8,7 @@ var ModelList = []string{
 	"qwq-32b",
 	"qwen3-235b-a22b",
 	"text-embedding-v1",
+	"gte-rerank-v2",
 }
 
 var ChannelName = "ali"

+ 27 - 0
relay/channel/ali/dto.go

@@ -1,5 +1,7 @@
 package ali
 
+import "one-api/dto"
+
 type AliMessage struct {
 	Content string `json:"content"`
 	Role    string `json:"role"`
@@ -97,3 +99,28 @@ type AliImageRequest struct {
 	} `json:"parameters,omitempty"`
 	ResponseFormat string `json:"response_format,omitempty"`
 }
+
+type AliRerankParameters struct {
+	TopN            *int  `json:"top_n,omitempty"`
+	ReturnDocuments *bool `json:"return_documents,omitempty"`
+}
+
+type AliRerankInput struct {
+	Query     string `json:"query"`
+	Documents []any  `json:"documents"`
+}
+
+type AliRerankRequest struct {
+	Model      string              `json:"model"`
+	Input      AliRerankInput      `json:"input"`
+	Parameters AliRerankParameters `json:"parameters,omitempty"`
+}
+
+type AliRerankResponse struct {
+	Output struct {
+		Results []dto.RerankResponseResult `json:"results"`
+	} `json:"output"`
+	Usage     AliUsage `json:"usage"`
+	RequestId string   `json:"request_id"`
+	AliError
+}

+ 83 - 0
relay/channel/ali/rerank.go

@@ -0,0 +1,83 @@
+package ali
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
+	returnDocuments := request.ReturnDocuments
+	if returnDocuments == nil {
+		t := true
+		returnDocuments = &t
+	}
+	return &AliRerankRequest{
+		Model: request.Model,
+		Input: AliRerankInput{
+			Query:     request.Query,
+			Documents: request.Documents,
+		},
+		Parameters: AliRerankParameters{
+			TopN:            &request.TopN,
+			ReturnDocuments: returnDocuments,
+		},
+	}
+}
+
+func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	var aliResponse AliRerankResponse
+	err = json.Unmarshal(responseBody, &aliResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if aliResponse.Code != "" {
+		return &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: aliResponse.Message,
+				Type:    aliResponse.Code,
+				Param:   aliResponse.RequestId,
+				Code:    aliResponse.Code,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	usage := dto.Usage{
+		PromptTokens:     aliResponse.Usage.TotalTokens,
+		CompletionTokens: 0,
+		TotalTokens:      aliResponse.Usage.TotalTokens,
+	}
+	rerankResponse := dto.RerankResponse{
+		Results: aliResponse.Output.Results,
+		Usage:   usage,
+	}
+
+	jsonResponse, err := json.Marshal(rerankResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	return nil, &usage
+}