فهرست منبع

Merge pull request #1537 from RedwindA/feat/support-native-gemini-embedding

feat: 支持原生Gemini Embedding格式
Calcium-Ion 4 ماه پیش
والد
کامیت
02fd80b703
6فایلهای تغییر یافته به همراه187 افزوده شده و 26 حذف شده
  1. 5 1
      controller/relay.go
  2. 8 3
      dto/gemini.go
  3. 13 14
      relay/channel/gemini/adaptor.go
  4. 38 1
      relay/channel/gemini/relay-gemini-native.go
  5. 8 7
      relay/common/relay_info.go
  6. 115 0
      relay/gemini_handler.go

+ 5 - 1
controller/relay.go

@@ -42,7 +42,11 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
 	case relayconstant.RelayModeGemini:
-		err = relay.GeminiHelper(c)
+		if strings.Contains(c.Request.URL.Path, "embed") {
+			err = relay.GeminiEmbeddingHandler(c)
+		} else {
+			err = relay.GeminiHelper(c)
+		}
 	default:
 		err = relay.TextHelper(c)
 	}

+ 8 - 3
dto/gemini.go

@@ -210,6 +210,7 @@ type GeminiImagePrediction struct {
 
 // Embedding related structs
 type GeminiEmbeddingRequest struct {
+	Model                string            `json:"model,omitempty"`
 	Content              GeminiChatContent `json:"content"`
 	TaskType             string            `json:"taskType,omitempty"`
 	Title                string            `json:"title,omitempty"`
@@ -220,10 +221,14 @@ type GeminiBatchEmbeddingRequest struct {
 	Requests []*GeminiEmbeddingRequest `json:"requests"`
 }
 
-type GeminiEmbedding struct {
-	Values []float64 `json:"values"`
+type GeminiEmbeddingResponse struct {
+	Embedding ContentEmbedding `json:"embedding"`
 }
 
 type GeminiBatchEmbeddingResponse struct {
-	Embeddings []*GeminiEmbedding `json:"embeddings"`
+	Embeddings []*ContentEmbedding `json:"embeddings"`
+}
+
+type ContentEmbedding struct {
+	Values []float64 `json:"values"`
 }

+ 13 - 14
relay/channel/gemini/adaptor.go

@@ -114,7 +114,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "embedding") ||
 		strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
-		return fmt.Sprintf("%s/%s/models/%s:batchEmbedContents", info.BaseUrl, version, info.UpstreamModelName), nil
+		action := "embedContent"
+		if info.IsGeminiBatchEmbedding {
+			action = "batchEmbedContents"
+		}
+		return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 	}
 
 	action := "generateContent"
@@ -159,6 +163,9 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	if len(inputs) == 0 {
 		return nil, errors.New("input is empty")
 	}
+	// We always build a batch-style payload with `requests`, so ensure we call the
+	// batch endpoint upstream to avoid payload/endpoint mismatches.
+	info.IsGeminiBatchEmbedding = true
 	// process all inputs
 	geminiRequests := make([]map[string]interface{}, 0, len(inputs))
 	for _, input := range inputs {
@@ -176,7 +183,7 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 		// set specific parameters for different models
 		// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
 		switch info.UpstreamModelName {
-		case "text-embedding-004","gemini-embedding-exp-03-07","gemini-embedding-001":
+		case "text-embedding-004", "gemini-embedding-exp-03-07", "gemini-embedding-001":
 			// Only newer models introduced after 2024 support OutputDimensionality
 			if request.Dimensions > 0 {
 				geminiRequest["outputDimensionality"] = request.Dimensions
@@ -201,6 +208,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeGemini {
+		if strings.HasSuffix(info.RequestURLPath, ":embedContent") ||
+			strings.HasSuffix(info.RequestURLPath, ":batchEmbedContents") {
+			return NativeGeminiEmbeddingHandler(c, resp, info)
+		}
 		if info.IsStream {
 			return GeminiTextGenerationStreamHandler(c, info, resp)
 		} else {
@@ -225,18 +236,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		return GeminiChatHandler(c, info, resp)
 	}
 
-	//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
-	//	// 没有请求-thinking的情况下,产生思考token,则按照思考模型计费
-	//	if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
-	//		!strings.HasSuffix(info.OriginModelName, "-nothinking") {
-	//		thinkingModelName := info.OriginModelName + "-thinking"
-	//		if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
-	//			info.OriginModelName = thinkingModelName
-	//		}
-	//	}
-	//}
-
-	return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
 }
 
 func (a *Adaptor) GetModelList() []string {

+ 38 - 1
relay/channel/gemini/relay-gemini-native.go

@@ -1,7 +1,6 @@
 package gemini
 
 import (
-	"github.com/pkg/errors"
 	"io"
 	"net/http"
 	"one-api/common"
@@ -12,6 +11,8 @@ import (
 	"one-api/types"
 	"strings"
 
+	"github.com/pkg/errors"
+
 	"github.com/gin-gonic/gin"
 )
 
@@ -63,6 +64,42 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
 	return &usage, nil
 }
 
+func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
+	defer common.CloseResponseBodyGracefully(resp)
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	usage := &dto.Usage{
+		PromptTokens: info.PromptTokens,
+		TotalTokens:  info.PromptTokens,
+	}
+
+	if info.IsGeminiBatchEmbedding {
+		var geminiResponse dto.GeminiBatchEmbeddingResponse
+		err = common.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	} else {
+		var geminiResponse dto.GeminiEmbeddingResponse
+		err = common.Unmarshal(responseBody, &geminiResponse)
+		if err != nil {
+			return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
+		}
+	}
+
+	common.IOCopyBytesGracefully(c, resp, responseBody)
+
+	return usage, nil
+}
+
 func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
 	var usage = &dto.Usage{}
 	var imageCount int

+ 8 - 7
relay/common/relay_info.go

@@ -74,13 +74,14 @@ type RelayInfo struct {
 	FirstResponseTime    time.Time
 	isFirstResponse      bool
 	//SendLastReasoningResponse bool
-	ApiType           int
-	IsStream          bool
-	IsPlayground      bool
-	UsePrice          bool
-	RelayMode         int
-	UpstreamModelName string
-	OriginModelName   string
+	ApiType                int
+	IsStream               bool
+	IsGeminiBatchEmbedding bool
+	IsPlayground           bool
+	UsePrice               bool
+	RelayMode              int
+	UpstreamModelName      string
+	OriginModelName        string
 	//RecodeModelName      string
 	RequestURLPath       string
 	ApiVersion           string

+ 115 - 0
relay/gemini_handler.go

@@ -264,3 +264,118 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
 	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
 	return nil
 }
+
+func GeminiEmbeddingHandler(c *gin.Context) (newAPIError *types.NewAPIError) {
+	relayInfo := relaycommon.GenRelayInfoGemini(c)
+
+	isBatch := strings.HasSuffix(c.Request.URL.Path, "batchEmbedContents")
+	relayInfo.IsGeminiBatchEmbedding = isBatch
+
+	var promptTokens int
+	var req any
+	var err error
+	var inputTexts []string
+
+	if isBatch {
+		batchRequest := &dto.GeminiBatchEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, batchRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = batchRequest
+		for _, r := range batchRequest.Requests {
+			for _, part := range r.Content.Parts {
+				if part.Text != "" {
+					inputTexts = append(inputTexts, part.Text)
+				}
+			}
+		}
+	} else {
+		singleRequest := &dto.GeminiEmbeddingRequest{}
+		err = common.UnmarshalBodyReusable(c, singleRequest)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
+		}
+		req = singleRequest
+		for _, part := range singleRequest.Content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+	promptTokens = service.CountTokenInput(strings.Join(inputTexts, "\n"), relayInfo.UpstreamModelName)
+	relayInfo.SetPromptTokens(promptTokens)
+	c.Set("prompt_tokens", promptTokens)
+
+	err = helper.ModelMappedHelper(c, relayInfo, req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
+	}
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, 0)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
+	}
+
+	preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if newAPIError != nil {
+		return newAPIError
+	}
+	defer func() {
+		if newAPIError != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
+	}
+	adaptor.Init(relayInfo)
+
+	var requestBody io.Reader
+	jsonData, err := common.Marshal(req)
+	if err != nil {
+		return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry())
+	}
+
+	// apply param override
+	if len(relayInfo.ParamOverride) > 0 {
+		reqMap := make(map[string]interface{})
+		_ = common.Unmarshal(jsonData, &reqMap)
+		for key, value := range relayInfo.ParamOverride {
+			reqMap[key] = value
+		}
+		jsonData, err = common.Marshal(reqMap)
+		if err != nil {
+			return types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid, types.ErrOptionWithSkipRetry())
+		}
+	}
+	requestBody = bytes.NewReader(jsonData)
+
+	resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
+	if err != nil {
+		common.LogError(c, "Do gemini request failed: "+err.Error())
+		return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError)
+	}
+
+	statusCodeMappingStr := c.GetString("status_code_mapping")
+	var httpResp *http.Response
+	if resp != nil {
+		httpResp = resp.(*http.Response)
+		if httpResp.StatusCode != http.StatusOK {
+			newAPIError = service.RelayErrorHandler(httpResp, false)
+			service.ResetStatusCode(newAPIError, statusCodeMappingStr)
+			return newAPIError
+		}
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	if openaiErr != nil {
+		service.ResetStatusCode(openaiErr, statusCodeMappingStr)
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	return nil
+}