Browse Source

feat: jina embed support (#146)

zijiren 8 months ago
parent
commit
1baed236f6

+ 1 - 2
core/relay/adaptor/baidu/adaptor.go

@@ -92,8 +92,7 @@ func (a *Adaptor) SetupRequestHeader(meta *meta.Meta, _ *gin.Context, req *http.
 func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
 	switch meta.Mode {
 	case mode.Embeddings:
-		meta.Set(openai.MetaEmbeddingsPatchInputToSlices, true)
-		return openai.ConvertRequest(meta, req)
+		return openai.ConvertEmbeddingsRequest(meta, req, true)
 	case mode.Rerank:
 		return openai.ConvertRequest(meta, req)
 	case mode.ImagesGenerations:

+ 11 - 2
core/relay/adaptor/jina/adaptor.go

@@ -1,7 +1,7 @@
 package jina
 
 import (
-	"fmt"
+	"io"
 	"net/http"
 
 	"github.com/gin-gonic/gin"
@@ -22,12 +22,21 @@ func (a *Adaptor) GetBaseURL() string {
 	return baseURL
 }
 
+func (a *Adaptor) ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
+	switch meta.Mode {
+	case mode.Embeddings:
+		return ConvertEmbeddingsRequest(meta, req)
+	default:
+		return a.Adaptor.ConvertRequest(meta, req)
+	}
+}
+
 func (a *Adaptor) DoResponse(meta *meta.Meta, c *gin.Context, resp *http.Response) (usage *model.Usage, err *relaymodel.ErrorWithStatusCode) {
 	switch meta.Mode {
 	case mode.Rerank:
 		return RerankHandler(meta, c, resp)
 	default:
-		return nil, openai.ErrorWrapperWithMessage(fmt.Sprintf("unsupported mode: %s", meta.Mode), "unsupported_mode", http.StatusBadRequest)
+		return a.Adaptor.DoResponse(meta, c, resp)
 	}
 }
 

+ 35 - 0
core/relay/adaptor/jina/embeddings.go

@@ -0,0 +1,35 @@
+package jina
+
+import (
+	"bytes"
+	"io"
+	"net/http"
+
+	"github.com/bytedance/sonic"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/relay/meta"
+)
+
+//nolint:gocritic
+func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
+	reqMap := make(map[string]any)
+	err := common.UnmarshalBodyReusable(req, &reqMap)
+	if err != nil {
+		return "", nil, nil, err
+	}
+
+	reqMap["model"] = meta.ActualModel
+
+	switch v := reqMap["input"].(type) {
+	case string:
+		reqMap["input"] = []string{v}
+	}
+
+	delete(reqMap, "encoding_format")
+
+	jsonData, err := sonic.Marshal(reqMap)
+	if err != nil {
+		return "", nil, nil, err
+	}
+	return http.MethodPost, nil, bytes.NewReader(jsonData), nil
+}

+ 2 - 3
core/relay/adaptor/openai/adaptor.go

@@ -75,10 +75,9 @@ func ConvertRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io
 	}
 	switch meta.Mode {
 	case mode.Moderations:
-		meta.Set(MetaEmbeddingsPatchInputToSlices, true)
-		return ConvertEmbeddingsRequest(meta, req)
+		return ConvertEmbeddingsRequest(meta, req, true)
 	case mode.Embeddings, mode.Completions:
-		return ConvertEmbeddingsRequest(meta, req)
+		return ConvertEmbeddingsRequest(meta, req, false)
 	case mode.ChatCompletions:
 		return ConvertTextRequest(meta, req, false)
 	case mode.ImagesGenerations:

+ 2 - 4
core/relay/adaptor/openai/embeddings.go

@@ -10,10 +10,8 @@ import (
 	"github.com/labring/aiproxy/core/relay/meta"
 )
 
-const MetaEmbeddingsPatchInputToSlices = "embeddings_input_to_slices"
-
 //nolint:gocritic
-func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.Header, io.Reader, error) {
+func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request, inputToSlices bool) (string, http.Header, io.Reader, error) {
 	reqMap := make(map[string]any)
 	err := common.UnmarshalBodyReusable(req, &reqMap)
 	if err != nil {
@@ -22,7 +20,7 @@ func ConvertEmbeddingsRequest(meta *meta.Meta, req *http.Request) (string, http.
 
 	reqMap["model"] = meta.ActualModel
 
-	if meta.GetBool(MetaEmbeddingsPatchInputToSlices) {
+	if inputToSlices {
 		switch v := reqMap["input"].(type) {
 		case string:
 			reqMap["input"] = []string{v}