Browse Source

feat: support models deployed by huggingface text-embeddings-inference (#175)

* feat: support models text-embeddings-inference

* fix: ci lint

* fix: review.

* fix: ci lint.

* fix: review

* fix: rerank response

* feat: impl rerank return documents

* test: add rerank test

---------

Co-authored-by: zijiren233 <[email protected]>
yy 7 months ago
parent
commit
623d22811c

+ 70 - 68
core/model/chtype.go

@@ -12,77 +12,79 @@ func (c ChannelType) String() string {
 }
 }
 
 
 const (
 const (
-	ChannelTypeOpenAI             ChannelType = 1
-	ChannelTypeAzure              ChannelType = 3
-	ChannelTypeGoogleGeminiOpenAI ChannelType = 12
-	ChannelTypeBaiduV2            ChannelType = 13
-	ChannelTypeAnthropic          ChannelType = 14
-	ChannelTypeBaidu              ChannelType = 15
-	ChannelTypeZhipu              ChannelType = 16
-	ChannelTypeAli                ChannelType = 17
-	ChannelTypeXunfei             ChannelType = 18
-	ChannelTypeAI360              ChannelType = 19
-	ChannelTypeOpenRouter         ChannelType = 20
-	ChannelTypeTencent            ChannelType = 23
-	ChannelTypeGoogleGemini       ChannelType = 24
-	ChannelTypeMoonshot           ChannelType = 25
-	ChannelTypeBaichuan           ChannelType = 26
-	ChannelTypeMinimax            ChannelType = 27
-	ChannelTypeMistral            ChannelType = 28
-	ChannelTypeGroq               ChannelType = 29
-	ChannelTypeOllama             ChannelType = 30
-	ChannelTypeLingyiwanwu        ChannelType = 31
-	ChannelTypeStepfun            ChannelType = 32
-	ChannelTypeAWS                ChannelType = 33
-	ChannelTypeCoze               ChannelType = 34
-	ChannelTypeCohere             ChannelType = 35
-	ChannelTypeDeepseek           ChannelType = 36
-	ChannelTypeCloudflare         ChannelType = 37
-	ChannelTypeDoubao             ChannelType = 40
-	ChannelTypeNovita             ChannelType = 41
-	ChannelTypeVertexAI           ChannelType = 42
-	ChannelTypeSiliconflow        ChannelType = 43
-	ChannelTypeDoubaoAudio        ChannelType = 44
-	ChannelTypeXAI                ChannelType = 45
-	ChannelTypeDoc2x              ChannelType = 46
-	ChannelTypeJina               ChannelType = 47
+	ChannelTypeOpenAI                  ChannelType = 1
+	ChannelTypeAzure                   ChannelType = 3
+	ChannelTypeGoogleGeminiOpenAI      ChannelType = 12
+	ChannelTypeBaiduV2                 ChannelType = 13
+	ChannelTypeAnthropic               ChannelType = 14
+	ChannelTypeBaidu                   ChannelType = 15
+	ChannelTypeZhipu                   ChannelType = 16
+	ChannelTypeAli                     ChannelType = 17
+	ChannelTypeXunfei                  ChannelType = 18
+	ChannelTypeAI360                   ChannelType = 19
+	ChannelTypeOpenRouter              ChannelType = 20
+	ChannelTypeTencent                 ChannelType = 23
+	ChannelTypeGoogleGemini            ChannelType = 24
+	ChannelTypeMoonshot                ChannelType = 25
+	ChannelTypeBaichuan                ChannelType = 26
+	ChannelTypeMinimax                 ChannelType = 27
+	ChannelTypeMistral                 ChannelType = 28
+	ChannelTypeGroq                    ChannelType = 29
+	ChannelTypeOllama                  ChannelType = 30
+	ChannelTypeLingyiwanwu             ChannelType = 31
+	ChannelTypeStepfun                 ChannelType = 32
+	ChannelTypeAWS                     ChannelType = 33
+	ChannelTypeCoze                    ChannelType = 34
+	ChannelTypeCohere                  ChannelType = 35
+	ChannelTypeDeepseek                ChannelType = 36
+	ChannelTypeCloudflare              ChannelType = 37
+	ChannelTypeDoubao                  ChannelType = 40
+	ChannelTypeNovita                  ChannelType = 41
+	ChannelTypeVertexAI                ChannelType = 42
+	ChannelTypeSiliconflow             ChannelType = 43
+	ChannelTypeDoubaoAudio             ChannelType = 44
+	ChannelTypeXAI                     ChannelType = 45
+	ChannelTypeDoc2x                   ChannelType = 46
+	ChannelTypeJina                    ChannelType = 47
+	ChannelTypeTextEmbeddingsInference ChannelType = 48
 )
 )
 
 
 var channelTypeNames = map[ChannelType]string{
 var channelTypeNames = map[ChannelType]string{
-	ChannelTypeOpenAI:             "openai",
-	ChannelTypeAzure:              "azure",
-	ChannelTypeGoogleGeminiOpenAI: "google gemini (openai)",
-	ChannelTypeBaiduV2:            "baidu v2",
-	ChannelTypeAnthropic:          "anthropic",
-	ChannelTypeBaidu:              "baidu",
-	ChannelTypeZhipu:              "zhipu",
-	ChannelTypeAli:                "ali",
-	ChannelTypeXunfei:             "xunfei",
-	ChannelTypeAI360:              "ai360",
-	ChannelTypeOpenRouter:         "openrouter",
-	ChannelTypeTencent:            "tencent",
-	ChannelTypeGoogleGemini:       "google gemini",
-	ChannelTypeMoonshot:           "moonshot",
-	ChannelTypeBaichuan:           "baichuan",
-	ChannelTypeMinimax:            "minimax",
-	ChannelTypeMistral:            "mistral",
-	ChannelTypeGroq:               "groq",
-	ChannelTypeOllama:             "ollama",
-	ChannelTypeLingyiwanwu:        "lingyiwanwu",
-	ChannelTypeStepfun:            "stepfun",
-	ChannelTypeAWS:                "aws",
-	ChannelTypeCoze:               "coze",
-	ChannelTypeCohere:             "Cohere",
-	ChannelTypeDeepseek:           "deepseek",
-	ChannelTypeCloudflare:         "cloudflare",
-	ChannelTypeDoubao:             "doubao",
-	ChannelTypeNovita:             "novita",
-	ChannelTypeVertexAI:           "vertexai",
-	ChannelTypeSiliconflow:        "siliconflow",
-	ChannelTypeDoubaoAudio:        "doubao audio",
-	ChannelTypeXAI:                "xai",
-	ChannelTypeDoc2x:              "doc2x",
-	ChannelTypeJina:               "jina",
+	ChannelTypeOpenAI:                  "openai",
+	ChannelTypeAzure:                   "azure",
+	ChannelTypeGoogleGeminiOpenAI:      "google gemini (openai)",
+	ChannelTypeBaiduV2:                 "baidu v2",
+	ChannelTypeAnthropic:               "anthropic",
+	ChannelTypeBaidu:                   "baidu",
+	ChannelTypeZhipu:                   "zhipu",
+	ChannelTypeAli:                     "ali",
+	ChannelTypeXunfei:                  "xunfei",
+	ChannelTypeAI360:                   "ai360",
+	ChannelTypeOpenRouter:              "openrouter",
+	ChannelTypeTencent:                 "tencent",
+	ChannelTypeGoogleGemini:            "google gemini",
+	ChannelTypeMoonshot:                "moonshot",
+	ChannelTypeBaichuan:                "baichuan",
+	ChannelTypeMinimax:                 "minimax",
+	ChannelTypeMistral:                 "mistral",
+	ChannelTypeGroq:                    "groq",
+	ChannelTypeOllama:                  "ollama",
+	ChannelTypeLingyiwanwu:             "lingyiwanwu",
+	ChannelTypeStepfun:                 "stepfun",
+	ChannelTypeAWS:                     "aws",
+	ChannelTypeCoze:                    "coze",
+	ChannelTypeCohere:                  "Cohere",
+	ChannelTypeDeepseek:                "deepseek",
+	ChannelTypeCloudflare:              "cloudflare",
+	ChannelTypeDoubao:                  "doubao",
+	ChannelTypeNovita:                  "novita",
+	ChannelTypeVertexAI:                "vertexai",
+	ChannelTypeSiliconflow:             "siliconflow",
+	ChannelTypeDoubaoAudio:             "doubao audio",
+	ChannelTypeXAI:                     "xai",
+	ChannelTypeDoc2x:                   "doc2x",
+	ChannelTypeJina:                    "jina",
+	ChannelTypeTextEmbeddingsInference: "huggingface text-embeddings-inference",
 }
 }
 
 
 func AllChannelTypes() []ChannelType {
 func AllChannelTypes() []ChannelType {

+ 73 - 0
core/relay/adaptor/text-embeddings-inference/adaptor.go

@@ -0,0 +1,73 @@
+package textembeddingsinference
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/labring/aiproxy/core/relay/mode"
+	relaymodel "github.com/labring/aiproxy/core/relay/model"
+	"github.com/labring/aiproxy/core/relay/utils"
+)
+
// text-embeddings-inference adaptor supports rerank and embeddings models deployed by https://github.com/huggingface/text-embeddings-inference
type Adaptor struct{}

// base url for text-embeddings-inference, fake: self-hosted TEI deployments
// have no canonical public endpoint, so each channel must configure its own
// BaseURL; this value is only a placeholder default.
const baseURL = "https://api.text-embeddings.net"

// GetBaseURL returns the placeholder default base URL for this channel type.
func (a *Adaptor) GetBaseURL() string {
	return baseURL
}
+
// GetModelList returns the models pre-declared for this adaptor (see ModelList).
func (a *Adaptor) GetModelList() []*model.ModelConfig {
	return ModelList
}
+
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+	switch meta.Mode {
+	case mode.Rerank:
+		return meta.Channel.BaseURL + "/rerank", nil
+	case mode.Embeddings:
+		return meta.Channel.BaseURL + "/v1/embeddings", nil
+	default:
+		return "", fmt.Errorf("unsupported mode: %s", meta.Mode)
+	}
+}
+
+// text-embeddings-inference api see https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
+
+func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.Request) error {
+	req.Header.Set("Authorization", "Bearer "+meta.Channel.Key)
+	return nil
+}
+
+func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
+	switch meta.Mode {
+	case mode.Rerank:
+		return ConvertRerankRequest(meta, req)
+	case mode.Embeddings:
+		return openai.ConvertRequest(meta, req)
+	default:
+		return "", nil, nil, fmt.Errorf("unsupported mode: %s", meta.Mode)
+	}
+}
+
// DoRequest forwards the prepared request via the shared relay HTTP helper.
func (a *Adaptor) DoRequest(_ *meta.Meta, _ *gin.Context, req *http.Request) (*http.Response, error) {
	return utils.DoRequest(req)
}
+
+func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {
+	switch meta.Mode {
+	case mode.Rerank:
+		return RerankHandler(meta, c, resp)
+	case mode.Embeddings:
+		return EmbeddingsHandler(meta, c, resp)
+	default:
+		return nil, openai.ErrorWrapperWithMessage(fmt.Sprintf("unsupported mode: %s", meta.Mode), "unsupported_mode", http.StatusBadRequest)
+	}
+}

+ 22 - 0
core/relay/adaptor/text-embeddings-inference/constants.go

@@ -0,0 +1,22 @@
+package textembeddingsinference
+
+import (
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/relay/mode"
+)
+
// maybe we should use a list of models from https://github.com/huggingface/text-embeddings-inference?tab=readme-ov-file#supported-models
// ModelList is the default catalogue for this channel type; only the BAAI
// bge-reranker-v2-m3 reranker is pre-declared for now.
var ModelList = []*model.ModelConfig{
	{
		Model: "bge-reranker-v2-m3",
		Type:  mode.Rerank,
		Owner: model.ModelOwnerBAAI,
		Price: model.Price{
			// NOTE(review): price unit is not visible here (presumably the
			// project's standard per-token unit) — confirm against model.Price.
			InputPrice:  0.015,
			OutputPrice: 0.015,
		},
		Config: model.NewModelConfig(
			model.WithModelConfigMaxContextTokens(32768),
		),
	},
}

+ 18 - 0
core/relay/adaptor/text-embeddings-inference/embeddings.go

@@ -0,0 +1,18 @@
+package textembeddingsinference
+
+import (
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/labring/aiproxy/core/relay/meta"
+	relaymodel "github.com/labring/aiproxy/core/relay/model"
+)
+
+func EmbeddingsHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {
+	if resp.StatusCode != http.StatusOK {
+		return nil, EmbeddingsErrorHanlder(resp)
+	}
+	return openai.DoResponse(meta, c, resp)
+}

+ 57 - 0
core/relay/adaptor/text-embeddings-inference/error.go

@@ -0,0 +1,57 @@
+package textembeddingsinference
+
+import (
+	"net/http"
+
+	"github.com/bytedance/sonic"
+	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/labring/aiproxy/core/relay/model"
+)
+
+type RerankErrorResponse struct {
+	Error     string `json:"error"`
+	ErrorType string `json:"error_type"`
+}
+
+func RerankErrorHanlder(resp *http.Response) *model.ErrorWithStatusCode {
+	defer resp.Body.Close()
+
+	errResp := RerankErrorResponse{}
+	err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&errResp)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+
+	return &model.ErrorWithStatusCode{
+		Error: model.Error{
+			Message: errResp.Error,
+			Type:    errResp.ErrorType,
+			Code:    resp.StatusCode,
+		},
+		StatusCode: resp.StatusCode,
+	}
+}
+
+type EmbeddingsErrorResponse struct {
+	Type    string `json:"type"`
+	Message string `json:"message"`
+}
+
+func EmbeddingsErrorHanlder(resp *http.Response) *model.ErrorWithStatusCode {
+	defer resp.Body.Close()
+
+	errResp := EmbeddingsErrorResponse{}
+	err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&errResp)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+
+	return &model.ErrorWithStatusCode{
+		Error: model.Error{
+			Message: errResp.Message,
+			Type:    errResp.Type,
+			Code:    resp.StatusCode,
+		},
+		StatusCode: resp.StatusCode,
+	}
+}

+ 143 - 0
core/relay/adaptor/text-embeddings-inference/rerank.go

@@ -0,0 +1,143 @@
+package textembeddingsinference
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/bytedance/sonic"
+	"github.com/bytedance/sonic/ast"
+	"github.com/gin-gonic/gin"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/relay/adaptor/openai"
+	"github.com/labring/aiproxy/core/relay/meta"
+	relaymodel "github.com/labring/aiproxy/core/relay/model"
+)
+
// ConvertRerankRequest rewrites an inbound rerank request into the
// text-embeddings-inference /rerank schema:
//   - "model" is overwritten with meta.ActualModel
//   - "documents" (required) is renamed to "texts"
//   - "return_documents" (optional bool) is renamed to "return_text"
//
// It returns the HTTP method, extra headers (none), and the rewritten body.
// The set/unset order on the AST matters: the value is copied to its new key
// before the old key is removed.
func ConvertRerankRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
	node, err := common.UnmarshalBody2Node(req)
	if err != nil {
		return "", nil, nil, fmt.Errorf("failed to parse request body: %w", err)
	}

	// Set the actual model in the request
	_, err = node.Set("model", ast.NewString(meta.ActualModel))
	if err != nil {
		return "", nil, nil, err
	}

	// Get the documents array and rename it to texts
	documentsNode := node.Get("documents")
	if !documentsNode.Exists() {
		return "", nil, nil, errors.New("documents field not found")
	}

	// Set the texts field with the documents value
	_, err = node.Set("texts", *documentsNode)
	if err != nil {
		return "", nil, nil, fmt.Errorf("failed to set texts field: %w", err)
	}

	// Remove the documents field
	_, err = node.Unset("documents")
	if err != nil {
		return "", nil, nil, fmt.Errorf("failed to remove documents field: %w", err)
	}

	// return_documents is optional; when present it becomes TEI's return_text.
	returnDocumentsNode := node.Get("return_documents")
	if returnDocumentsNode.Exists() {
		returnDocuments, err := returnDocumentsNode.Bool()
		if err != nil {
			return "", nil, nil, fmt.Errorf("failed to unmarshal return_documents field: %w", err)
		}
		_, err = node.Unset("return_documents")
		if err != nil {
			return "", nil, nil, fmt.Errorf("failed to remove return_documents field: %w", err)
		}
		_, err = node.Set("return_text", ast.NewBool(returnDocuments))
		if err != nil {
			return "", nil, nil, fmt.Errorf("failed to set return_text field: %w", err)
		}
	}

	// Convert back to JSON
	jsonData, err := node.MarshalJSON()
	if err != nil {
		return "", nil, nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	return http.MethodPost, nil, bytes.NewReader(jsonData), nil
}
+
+type RerankResponse []RerankResponseItem
+
+type RerankResponseItem struct {
+	Index int     `json:"index"`
+	Score float64 `json:"score"`
+	Text  string  `json:"text,omitempty"`
+}
+
+func (rri *RerankResponseItem) ToRerankModel() *relaymodel.RerankResult {
+	var document *relaymodel.Document
+	if rri.Text != "" {
+		document = &relaymodel.Document{
+			Text: rri.Text,
+		}
+	}
+	return &relaymodel.RerankResult{
+		Index:          rri.Index,
+		RelevanceScore: rri.Score,
+		Document:       document,
+	}
+}
+
+func RerankHandler(meta *meta.Meta, c *gin.Context, resp *http.Response) (*model.Usage, *relaymodel.ErrorWithStatusCode) {
+	if resp.StatusCode != http.StatusOK {
+		return nil, RerankErrorHanlder(resp)
+	}
+
+	defer resp.Body.Close()
+
+	log := middleware.GetLogger(c)
+
+	respSlice := RerankResponse{}
+	err := sonic.ConfigDefault.NewDecoder(resp.Body).Decode(&respSlice)
+	if err != nil {
+		return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+
+	usage := &model.Usage{
+		InputTokens: meta.RequestUsage.InputTokens,
+		TotalTokens: meta.RequestUsage.InputTokens,
+	}
+
+	results := make([]*relaymodel.RerankResult, len(respSlice))
+	for i, v := range respSlice {
+		results[i] = v.ToRerankModel()
+	}
+
+	rerankResp := relaymodel.RerankResponse{
+		Meta: relaymodel.RerankMeta{
+			Tokens: &relaymodel.RerankMetaTokens{
+				InputTokens: int64(usage.InputTokens),
+			},
+		},
+		Results: results,
+		ID:      meta.RequestID,
+	}
+
+	jsonResponse, err := sonic.Marshal(rerankResp)
+	if err != nil {
+		return usage, openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		log.Warnf("write response body failed: %v", err)
+	}
+	return usage, nil
+}

+ 127 - 0
core/relay/adaptor/text-embeddings-inference/rerank_test.go

@@ -0,0 +1,127 @@
+package textembeddingsinference_test
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+	"testing"
+
+	"github.com/bytedance/sonic"
+	textembeddingsinference "github.com/labring/aiproxy/core/relay/adaptor/text-embeddings-inference"
+	"github.com/labring/aiproxy/core/relay/meta"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestConvertRerankRequest(t *testing.T) {
+	// Test successful conversion
+	t.Run("successful conversion", func(t *testing.T) {
+		// Create mock request body
+		requestBody := map[string]interface{}{
+			"model": "original-model",
+			"documents": []string{
+				"This is document 1",
+				"This is document 2",
+			},
+			"query": "Find relevant documents",
+		}
+
+		jsonBody, err := sonic.Marshal(requestBody)
+		assert.NoError(t, err)
+
+		// Create mock HTTP request
+		req, err := http.NewRequest(http.MethodPost, "/rerank", bytes.NewReader(jsonBody))
+		assert.NoError(t, err)
+		req.Header.Set("Content-Type", "application/json")
+
+		// Create mock meta
+		testMeta := &meta.Meta{
+			ActualModel: "text-embeddings-model",
+		}
+
+		// Call the function under test
+		method, _, bodyReader, err := textembeddingsinference.ConvertRerankRequest(testMeta, req)
+
+		// Assert no error
+		assert.NoError(t, err)
+
+		// Assert method
+		assert.Equal(t, http.MethodPost, method)
+
+		// Read the transformed body
+		bodyBytes, err := io.ReadAll(bodyReader)
+		assert.NoError(t, err)
+
+		// Parse the body back to verify the transformation
+		var transformedBody map[string]interface{}
+		err = sonic.Unmarshal(bodyBytes, &transformedBody)
+		assert.NoError(t, err)
+
+		// Verify the model was replaced
+		assert.Equal(t, "text-embeddings-model", transformedBody["model"])
+
+		// Verify documents was renamed to texts
+		assert.NotContains(t, transformedBody, "documents")
+
+		// Verify texts contains the documents content
+		textsArray, ok := transformedBody["texts"].([]interface{})
+		assert.True(t, ok, "texts should be an array")
+		assert.Len(t, textsArray, 2)
+		assert.Equal(t, "This is document 1", textsArray[0])
+		assert.Equal(t, "This is document 2", textsArray[1])
+
+		// Verify query remains unchanged
+		assert.Equal(t, "Find relevant documents", transformedBody["query"])
+	})
+
+	// Test missing documents field
+	t.Run("missing documents field", func(t *testing.T) {
+		// Create mock request body without documents
+		requestBody := map[string]interface{}{
+			"model": "original-model",
+			"query": "Find relevant documents",
+		}
+
+		jsonBody, err := sonic.Marshal(requestBody)
+		assert.NoError(t, err)
+
+		// Create mock HTTP request
+		req, err := http.NewRequest(http.MethodPost, "/rerank", bytes.NewReader(jsonBody))
+		assert.NoError(t, err)
+		req.Header.Set("Content-Type", "application/json")
+
+		// Create mock meta
+		testMeta := &meta.Meta{
+			ActualModel: "text-embeddings-model",
+		}
+
+		// Call the function under test
+		_, _, _, err = textembeddingsinference.ConvertRerankRequest(testMeta, req)
+
+		// Assert error for missing documents
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "documents field not found")
+	})
+
+	// Test invalid JSON
+	t.Run("invalid json", func(t *testing.T) {
+		// Create invalid JSON body
+		invalidJSON := []byte(`{"model": "test", "documents": [`)
+
+		// Create mock HTTP request
+		req, err := http.NewRequest(http.MethodPost, "/rerank", bytes.NewReader(invalidJSON))
+		assert.NoError(t, err)
+		req.Header.Set("Content-Type", "application/json")
+
+		// Create mock meta
+		testMeta := &meta.Meta{
+			ActualModel: "text-embeddings-model",
+		}
+
+		// Call the function under test
+		_, _, _, err = textembeddingsinference.ConvertRerankRequest(testMeta, req)
+
+		// Assert error for invalid JSON
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "failed to parse request body")
+	})
+}

+ 36 - 34
core/relay/channeltype/define.go

@@ -33,6 +33,7 @@ import (
 	"github.com/labring/aiproxy/core/relay/adaptor/siliconflow"
 	"github.com/labring/aiproxy/core/relay/adaptor/siliconflow"
 	"github.com/labring/aiproxy/core/relay/adaptor/stepfun"
 	"github.com/labring/aiproxy/core/relay/adaptor/stepfun"
 	"github.com/labring/aiproxy/core/relay/adaptor/tencent"
 	"github.com/labring/aiproxy/core/relay/adaptor/tencent"
+	text_embeddings_inference "github.com/labring/aiproxy/core/relay/adaptor/text-embeddings-inference"
 	"github.com/labring/aiproxy/core/relay/adaptor/vertexai"
 	"github.com/labring/aiproxy/core/relay/adaptor/vertexai"
 	"github.com/labring/aiproxy/core/relay/adaptor/xai"
 	"github.com/labring/aiproxy/core/relay/adaptor/xai"
 	"github.com/labring/aiproxy/core/relay/adaptor/xunfei"
 	"github.com/labring/aiproxy/core/relay/adaptor/xunfei"
@@ -40,40 +41,41 @@ import (
 )
 )
 
 
 var ChannelAdaptor = map[model.ChannelType]adaptor.Adaptor{
 var ChannelAdaptor = map[model.ChannelType]adaptor.Adaptor{
-	model.ChannelTypeOpenAI:             &openai.Adaptor{},
-	model.ChannelTypeAzure:              &azure.Adaptor{},
-	model.ChannelTypeGoogleGeminiOpenAI: &geminiopenai.Adaptor{},
-	model.ChannelTypeBaiduV2:            &baiduv2.Adaptor{},
-	model.ChannelTypeAnthropic:          &anthropic.Adaptor{},
-	model.ChannelTypeBaidu:              &baidu.Adaptor{},
-	model.ChannelTypeZhipu:              &zhipu.Adaptor{},
-	model.ChannelTypeAli:                &ali.Adaptor{},
-	model.ChannelTypeXunfei:             &xunfei.Adaptor{},
-	model.ChannelTypeAI360:              &ai360.Adaptor{},
-	model.ChannelTypeOpenRouter:         &openrouter.Adaptor{},
-	model.ChannelTypeTencent:            &tencent.Adaptor{},
-	model.ChannelTypeGoogleGemini:       &gemini.Adaptor{},
-	model.ChannelTypeMoonshot:           &moonshot.Adaptor{},
-	model.ChannelTypeBaichuan:           &baichuan.Adaptor{},
-	model.ChannelTypeMinimax:            &minimax.Adaptor{},
-	model.ChannelTypeMistral:            &mistral.Adaptor{},
-	model.ChannelTypeGroq:               &groq.Adaptor{},
-	model.ChannelTypeOllama:             &ollama.Adaptor{},
-	model.ChannelTypeLingyiwanwu:        &lingyiwanwu.Adaptor{},
-	model.ChannelTypeStepfun:            &stepfun.Adaptor{},
-	model.ChannelTypeAWS:                &aws.Adaptor{},
-	model.ChannelTypeCoze:               &coze.Adaptor{},
-	model.ChannelTypeCohere:             &cohere.Adaptor{},
-	model.ChannelTypeDeepseek:           &deepseek.Adaptor{},
-	model.ChannelTypeCloudflare:         &cloudflare.Adaptor{},
-	model.ChannelTypeDoubao:             &doubao.Adaptor{},
-	model.ChannelTypeNovita:             &novita.Adaptor{},
-	model.ChannelTypeVertexAI:           &vertexai.Adaptor{},
-	model.ChannelTypeSiliconflow:        &siliconflow.Adaptor{},
-	model.ChannelTypeDoubaoAudio:        &doubaoaudio.Adaptor{},
-	model.ChannelTypeXAI:                &xai.Adaptor{},
-	model.ChannelTypeDoc2x:              &doc2x.Adaptor{},
-	model.ChannelTypeJina:               &jina.Adaptor{},
+	model.ChannelTypeOpenAI:                  &openai.Adaptor{},
+	model.ChannelTypeAzure:                   &azure.Adaptor{},
+	model.ChannelTypeGoogleGeminiOpenAI:      &geminiopenai.Adaptor{},
+	model.ChannelTypeBaiduV2:                 &baiduv2.Adaptor{},
+	model.ChannelTypeAnthropic:               &anthropic.Adaptor{},
+	model.ChannelTypeBaidu:                   &baidu.Adaptor{},
+	model.ChannelTypeZhipu:                   &zhipu.Adaptor{},
+	model.ChannelTypeAli:                     &ali.Adaptor{},
+	model.ChannelTypeXunfei:                  &xunfei.Adaptor{},
+	model.ChannelTypeAI360:                   &ai360.Adaptor{},
+	model.ChannelTypeOpenRouter:              &openrouter.Adaptor{},
+	model.ChannelTypeTencent:                 &tencent.Adaptor{},
+	model.ChannelTypeGoogleGemini:            &gemini.Adaptor{},
+	model.ChannelTypeMoonshot:                &moonshot.Adaptor{},
+	model.ChannelTypeBaichuan:                &baichuan.Adaptor{},
+	model.ChannelTypeMinimax:                 &minimax.Adaptor{},
+	model.ChannelTypeMistral:                 &mistral.Adaptor{},
+	model.ChannelTypeGroq:                    &groq.Adaptor{},
+	model.ChannelTypeOllama:                  &ollama.Adaptor{},
+	model.ChannelTypeLingyiwanwu:             &lingyiwanwu.Adaptor{},
+	model.ChannelTypeStepfun:                 &stepfun.Adaptor{},
+	model.ChannelTypeAWS:                     &aws.Adaptor{},
+	model.ChannelTypeCoze:                    &coze.Adaptor{},
+	model.ChannelTypeCohere:                  &cohere.Adaptor{},
+	model.ChannelTypeDeepseek:                &deepseek.Adaptor{},
+	model.ChannelTypeCloudflare:              &cloudflare.Adaptor{},
+	model.ChannelTypeDoubao:                  &doubao.Adaptor{},
+	model.ChannelTypeNovita:                  &novita.Adaptor{},
+	model.ChannelTypeVertexAI:                &vertexai.Adaptor{},
+	model.ChannelTypeSiliconflow:             &siliconflow.Adaptor{},
+	model.ChannelTypeDoubaoAudio:             &doubaoaudio.Adaptor{},
+	model.ChannelTypeXAI:                     &xai.Adaptor{},
+	model.ChannelTypeDoc2x:                   &doc2x.Adaptor{},
+	model.ChannelTypeJina:                    &jina.Adaptor{},
+	model.ChannelTypeTextEmbeddingsInference: &text_embeddings_inference.Adaptor{},
 }
 }
 
 
 func GetAdaptor(channelType model.ChannelType) (adaptor.Adaptor, bool) {
 func GetAdaptor(channelType model.ChannelType) (adaptor.Adaptor, bool) {