Просмотр исходного кода

feat: Improve embedding request handling and support across channels

- Update EmbeddingRequest DTO to support more flexible input types
- Add input parsing method to handle various input formats
- Implement ConvertEmbeddingRequest for multiple channel adaptors
- Remove relayMode parameter from EmbeddingHelper
- Add input validation for embedding requests
- Simplify embedding request conversion for different channels
[email protected] 10 месяцев назад
Родитель
Сommit
f5e3063f33

+ 1 - 7
controller/relay.go

@@ -34,7 +34,7 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 	case relayconstant.RelayModeRerank:
 		err = relay.RerankHelper(c, relayMode)
 	case relayconstant.RelayModeEmbeddings:
-		err = relay.EmbeddingHelper(c,relayMode)
+		err = relay.EmbeddingHelper(c)
 	default:
 		err = relay.TextHelper(c)
 	}
@@ -57,11 +57,6 @@ func Relay(c *gin.Context) {
 	originalModel := c.GetString("original_model")
 	var openaiErr *dto.OpenAIErrorWithStatusCode
 
-	//获取request body 并输出到日志
-	requestBody, _ := common.GetRequestBody(c)
-	common.LogInfo(c, fmt.Sprintf("relayMode: %d ,request body: %s",relayMode, string(requestBody)))
-	
-
 	for i := 0; i <= common.RetryTimes; i++ {
 		channel, err := getChannel(c, group, originalModel, i)
 		if err != nil {
@@ -161,7 +156,6 @@ func WssRelay(c *gin.Context) {
 }
 
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
-	common.LogInfo(c, fmt.Sprintf("relayMode: %d ,channel Id : %s",relayMode, string(channel.Id)))
 	addUsedChannel(c, channel.Id)
 	requestBody, _ := common.GetRequestBody(c)
 	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))

+ 32 - 5
dto/embedding.go

@@ -12,8 +12,35 @@ type EmbeddingOptions struct {
 }
 
 type EmbeddingRequest struct {
-	Model string   `json:"model"`
-	Input []string `json:"input"`
+	Model            string   `json:"model"`
+	Input            any      `json:"input"`
+	EncodingFormat   string   `json:"encoding_format,omitempty"`
+	Dimensions       int      `json:"dimensions,omitempty"`
+	User             string   `json:"user,omitempty"`
+	Seed             float64  `json:"seed,omitempty"`
+	Temperature      *float64 `json:"temperature,omitempty"`
+	TopP             float64  `json:"top_p,omitempty"`
+	FrequencyPenalty float64  `json:"frequency_penalty,omitempty"`
+	PresencePenalty  float64  `json:"presence_penalty,omitempty"`
+}
+
+func (r EmbeddingRequest) ParseInput() []string {
+	if r.Input == nil {
+		return nil
+	}
+	var input []string
+	switch r.Input.(type) {
+	case string:
+		input = []string{r.Input.(string)}
+	case []any:
+		input = make([]string, 0, len(r.Input.([]any)))
+		for _, item := range r.Input.([]any) {
+			if str, ok := item.(string); ok {
+				input = append(input, str)
+			}
+		}
+	}
+	return input
 }
 
 type EmbeddingResponseItem struct {
@@ -23,8 +50,8 @@ type EmbeddingResponseItem struct {
 }
 
 type EmbeddingResponse struct {
-	Object string                        `json:"object"`
+	Object string                  `json:"object"`
 	Data   []EmbeddingResponseItem `json:"data"`
-	Model  string                        `json:"model"`
+	Model  string                  `json:"model"`
 	Usage  `json:"usage"`
-}
+}

+ 1 - 5
relay/channel/ali/adaptor.go

@@ -49,9 +49,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 		return nil, errors.New("request is nil")
 	}
 	switch info.RelayMode {
-	case constant.RelayModeEmbeddings:
-		baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request)
-		return baiduEmbeddingRequest, nil
 	default:
 		aliReq := requestOpenAI2Ali(*request)
 		return aliReq, nil
@@ -68,8 +65,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return embeddingRequestOpenAI2Ali(request), nil
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

+ 5 - 2
relay/channel/ali/text.go

@@ -25,9 +25,12 @@ func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReque
 	return &request
 }
 
-func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest {
+func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
+	if request.Model == "" {
+		request.Model = "text-embedding-v1"
+	}
 	return &AliEmbeddingRequest{
-		Model: "text-embedding-v1",
+		Model: request.Model,
 		Input: struct {
 			Texts []string `json:"texts"`
 		}{

+ 2 - 5
relay/channel/baidu/adaptor.go

@@ -109,9 +109,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 		return nil, errors.New("request is nil")
 	}
 	switch info.RelayMode {
-	case constant.RelayModeEmbeddings:
-		baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request)
-		return baiduEmbeddingRequest, nil
 	default:
 		baiduRequest := requestOpenAI2Baidu(*request)
 		return baiduRequest, nil
@@ -123,8 +120,8 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
+	return baiduEmbeddingRequest, nil
 }
 
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {

+ 1 - 1
relay/channel/baidu/relay-baidu.go

@@ -87,7 +87,7 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.Cha
 	return &response
 }
 
-func embeddingRequestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduEmbeddingRequest {
+func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
 	return &BaiduEmbeddingRequest{
 		Input: request.ParseInput(),
 	}

+ 1 - 3
relay/channel/cloudflare/adaptor.go

@@ -57,11 +57,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
-
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
 	// 添加文件字段
 	file, _, err := c.Request.FormFile("file")

+ 1 - 3
relay/channel/jina/adaptor.go

@@ -56,11 +56,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
-
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeRerank {
 		err, usage = jinaRerankHandler(c, resp)

+ 2 - 9
relay/channel/ollama/adaptor.go

@@ -46,12 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
-	switch info.RelayMode {
-	case relayconstant.RelayModeEmbeddings:
-		return requestOpenAI2Embeddings(*request), nil
-	default:
-		return requestOpenAI2Ollama(*request), nil
-	}
+	return requestOpenAI2Ollama(*request), nil
 }
 
 func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -59,11 +54,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return requestOpenAI2Embeddings(request), nil
 }
 
-
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 6 - 6
relay/channel/ollama/relay-ollama.go

@@ -42,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
 	}
 }
 
-func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
+func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
 	return &OllamaEmbeddingRequest{
 		Model: request.Model,
 		Input: request.ParseInput(),
@@ -123,9 +123,9 @@ func ollamaEmbeddingHandler(c *gin.Context, resp *http.Response, promptTokens in
 }
 
 func flattenEmbeddings(embeddings [][]float64) []float64 {
-flattened := []float64{}
-for _, row := range embeddings {
-	flattened = append(flattened, row...)
+	flattened := []float64{}
+	for _, row := range embeddings {
+		flattened = append(flattened, row...)
+	}
+	return flattened
 }
-return flattened
-}

+ 1 - 2
relay/channel/openai/adaptor.go

@@ -150,8 +150,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

+ 1 - 3
relay/channel/siliconflow/adaptor.go

@@ -59,11 +59,9 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt
 }
 
 func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
-	//TODO implement me
-	return nil, errors.New("not implemented")
+	return request, nil
 }
 
-
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	switch info.RelayMode {
 	case constant.RelayModeRerank:

+ 22 - 12
relay/relay_embedding.go

@@ -19,7 +19,20 @@ func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
 	return token
 }
 
-func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+func validateEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, embeddingRequest dto.EmbeddingRequest) error {
+	if embeddingRequest.Input == nil {
+		return fmt.Errorf("input is empty")
+	}
+	if info.RelayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
+		embeddingRequest.Model = "omni-moderation-latest"
+	}
+	if info.RelayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
+		embeddingRequest.Model = c.Param("model")
+	}
+	return nil
+}
+
+func EmbeddingHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfo(c)
 
 	var embeddingRequest *dto.EmbeddingRequest
@@ -28,15 +41,12 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
 		common.LogError(c, fmt.Sprintf("getAndValidateTextRequest failed: %s", err.Error()))
 		return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest)
 	}
-	if relayMode == relayconstant.RelayModeModerations && embeddingRequest.Model == "" {
-		embeddingRequest.Model = "m3e-base"
-	}
-	if relayMode == relayconstant.RelayModeEmbeddings && embeddingRequest.Model == "" {
-		embeddingRequest.Model = c.Param("model")
-	}
-	if embeddingRequest.Input == nil || len(embeddingRequest.Input) == 0 {
-		return service.OpenAIErrorWrapperLocal(fmt.Errorf("input is empty"), "invalid_input", http.StatusBadRequest)
+
+	err = validateEmbeddingRequest(c, relayInfo, *embeddingRequest)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "invalid_embedding_request", http.StatusBadRequest)
 	}
+
 	// map model name
 	modelMapping := c.GetString("model_mapping")
 	//isModelMapped := false
@@ -89,8 +99,8 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
 	}
 	adaptor.Init(relayInfo)
 
-	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c,relayInfo,*embeddingRequest)
-	
+	convertedRequest, err := adaptor.ConvertEmbeddingRequest(c, relayInfo, *embeddingRequest)
+
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
 	}
@@ -100,7 +110,7 @@ func EmbeddingHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorW
 	}
 	requestBody := bytes.NewBuffer(jsonData)
 	statusCodeMappingStr := c.GetString("status_code_mapping")
-	resp, err := adaptor.DoRequest(c,relayInfo, requestBody)
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
 	if err != nil {
 		return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 	}

+ 8 - 1
service/token_counter.go

@@ -4,7 +4,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/pkoukk/tiktoken-go"
 	"image"
 	"log"
 	"math"
@@ -14,6 +13,8 @@ import (
 	relaycommon "one-api/relay/common"
 	"strings"
 	"unicode/utf8"
+
+	"github.com/pkoukk/tiktoken-go"
 )
 
 // tokenEncoderMap won't grow after initialization
@@ -323,6 +324,12 @@ func CountTokenInput(input any, model string) (int, error) {
 			text += s
 		}
 		return CountTextToken(text, model)
+	case []interface{}:
+		text := ""
+		for _, item := range v {
+			text += fmt.Sprintf("%v", item)
+		}
+		return CountTextToken(text, model)
 	}
 	return CountTokenInput(fmt.Sprintf("%v", input), model)
 }