Merge pull request #1107 from QuantumNous/gemini-relay

Gemini format
IcedTangerine 6 months ago
Parent
Commit
b0cbf71a1c

+ 2 - 0
controller/relay.go

@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
 		err = relay.EmbeddingHelper(c)
 	case relayconstant.RelayModeResponses:
 		err = relay.ResponsesHelper(c)
+	case relayconstant.RelayModeGemini:
+		err = relay.GeminiHelper(c)
 	default:
 		err = relay.TextHelper(c)
 	}

+ 10 - 2
middleware/auth.go

@@ -1,13 +1,14 @@
 package middleware
 
 import (
-	"github.com/gin-contrib/sessions"
-	"github.com/gin-gonic/gin"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
 	"strconv"
 	"strings"
+
+	"github.com/gin-contrib/sessions"
+	"github.com/gin-gonic/gin"
 )
 
 func validUserInfo(username string, role int) bool {
@@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) {
 				c.Request.Header.Set("Authorization", "Bearer "+key)
 			}
 		}
+		// Gemini API: read the key from the query string
+		if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+			skKey := c.Query("key")
+			if skKey != "" {
+				c.Request.Header.Set("Authorization", "Bearer "+skKey)
+			}
+		}
 		key := c.Request.Header.Get("Authorization")
 		parts := make([]string, 0)
 		key = strings.TrimPrefix(key, "Bearer ")
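
With this change, a Gemini-style client can pass its token as a "key" query parameter instead of an Authorization header; TokenAuth promotes it to "Authorization: Bearer <token>". A minimal client sketch, assuming a hypothetical gateway at localhost:3000 and a placeholder token:

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical base URL and token; substitute your own deployment values.
	url := "http://localhost:3000/v1beta/models/gemini-2.0-flash:generateContent?key=sk-your-token"
	body := []byte(`{"contents":[{"parts":[{"text":"Hello"}]}]}`)

	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status) // TokenAuth has copied ?key= into the Authorization header
}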

+ 36 - 0
middleware/distributor.go

@@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		}
 		c.Set("platform", string(constant.TaskPlatformSuno))
 		c.Set("relay_mode", relayMode)
+	} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
+	// Gemini API path handling: /v1beta/models/gemini-2.0-flash:generateContent
+		relayMode := relayconstant.RelayModeGemini
+		modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
+		if modelName != "" {
+			modelRequest.Model = modelName
+		}
+		c.Set("relay_mode", relayMode)
 	} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
 		err = common.UnmarshalBodyReusable(c, &modelRequest)
 	}
@@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 		c.Set("bot_id", channel.Other)
 	}
 }
+
+// extractModelNameFromGeminiPath extracts the model name from a Gemini API URL path.
+// Input format: /v1beta/models/gemini-2.0-flash:generateContent
+// Output: gemini-2.0-flash
+func extractModelNameFromGeminiPath(path string) string {
+	// 查找 "/models/" 的位置
+	modelsPrefix := "/models/"
+	modelsIndex := strings.Index(path, modelsPrefix)
+	if modelsIndex == -1 {
+		return ""
+	}
+
+	// 从 "/models/" 之后开始提取
+	startIndex := modelsIndex + len(modelsPrefix)
+	if startIndex >= len(path) {
+		return ""
+	}
+
+	// Locate ":"; the model name is everything before it
+	colonIndex := strings.Index(path[startIndex:], ":")
+	if colonIndex == -1 {
+		// No ":" found; return everything from "/models/" to the end of the path
+		return path[startIndex:]
+	}
+
+	// Return the model name portion
+	return path[startIndex : startIndex+colonIndex]
+}
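
A table-driven test sketch (not part of the commit) exercising extractModelNameFromGeminiPath against the path formats documented above:

package middleware

import "testing"

func TestExtractModelNameFromGeminiPath(t *testing.T) {
	cases := map[string]string{
		"/v1beta/models/gemini-2.0-flash:generateContent":       "gemini-2.0-flash",
		"/v1beta/models/gemini-2.0-flash:streamGenerateContent": "gemini-2.0-flash",
		"/v1beta/models/gemini-2.0-flash":                       "gemini-2.0-flash", // no ":" suffix
		"/v1beta/models/":                                       "",                 // nothing after the prefix
		"/v1/chat/completions":                                  "",                 // no "/models/" segment
	}
	for path, want := range cases {
		if got := extractModelNameFromGeminiPath(path); got != want {
			t.Errorf("extractModelNameFromGeminiPath(%q) = %q, want %q", path, got, want)
		}
	}
}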

+ 9 - 0
relay/channel/gemini/adaptor.go

@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/service"
 	"one-api/setting/model_setting"
 	"strings"
@@ -165,6 +166,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
+	if info.RelayMode == constant.RelayModeGemini {
+		if info.IsStream {
+			return GeminiTextGenerationStreamHandler(c, resp, info)
+		} else {
+			return GeminiTextGenerationHandler(c, resp, info)
+		}
+	}
+
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
 		return GeminiImageHandler(c, resp, info)
 	}

+ 128 - 0
relay/channel/gemini/relay-gemini-native.go

@@ -0,0 +1,128 @@
+package gemini
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+
+	"github.com/gin-gonic/gin"
+)
+
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	// Read the response body
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	err = resp.Body.Close()
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
+	}
+
+	if common.DebugEnabled {
+		println(string(responseBody))
+	}
+
+	// Parse as the Gemini native response format
+	var geminiResponse GeminiChatResponse
+	err = common.DecodeJson(responseBody, &geminiResponse)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// Ensure at least one candidate was returned
+	if len(geminiResponse.Candidates) == 0 {
+		return nil, &dto.OpenAIErrorWithStatusCode{
+			Error: dto.OpenAIError{
+				Message: "No candidates returned",
+				Type:    "server_error",
+				Param:   "",
+				Code:    500,
+			},
+			StatusCode: resp.StatusCode,
+		}
+	}
+
+	// Compute usage from UsageMetadata
+	usage := dto.Usage{
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
+	}
+
+	// Return the Gemini-native JSON response unchanged
+	jsonResponse, err := json.Marshal(geminiResponse)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
+	}
+
+	// Set response headers and write the body
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	if err != nil {
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
+	}
+
+	return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+		// Count inline images in this chunk
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+			}
+		}
+
+		// Update usage statistics
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+		}
+
+		// Forward the GeminiChatResponse chunk as-is
+		err = helper.ObjectData(c, geminiResponse)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// Derive the final completion count from totals; this overwrites the image-based estimate above
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+	// End the event stream
+	helper.Done(c)
+
+	return usage, nil
+}
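
For reference, a self-contained sketch of the decoding step above, using minimal stand-in types and a hypothetical response body (the real GeminiChatResponse in relay/channel/gemini carries many more fields):

package main

import (
	"encoding/json"
	"fmt"
)

type usageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

type geminiResponse struct {
	Candidates    []json.RawMessage `json:"candidates"`
	UsageMetadata usageMetadata     `json:"usageMetadata"`
}

func main() {
	body := []byte(`{
		"candidates": [{"content": {"parts": [{"text": "Hi!"}]}}],
		"usageMetadata": {"promptTokenCount": 4, "candidatesTokenCount": 2, "totalTokenCount": 6}
	}`)

	var r geminiResponse
	if err := json.Unmarshal(body, &r); err != nil {
		panic(err)
	}
	// Mirrors the usage mapping in GeminiTextGenerationHandler: prompt=4, completion=2, total=6.
	fmt.Println(len(r.Candidates), r.UsageMetadata.PromptTokenCount, r.UsageMetadata.CandidatesTokenCount, r.UsageMetadata.TotalTokenCount)
}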

+ 11 - 2
relay/channel/vertex/adaptor.go

@@ -12,6 +12,7 @@ import (
 	"one-api/relay/channel/gemini"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/setting/model_setting"
 	"strings"
 
@@ -201,7 +202,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 			err, usage = openai.OaiStreamHandler(c, resp, info)
 		}
@@ -210,7 +215,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 		case RequestModeClaude:
 			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
 		case RequestModeGemini:
-			err, usage = gemini.GeminiChatHandler(c, resp, info)
+			if info.RelayMode == constant.RelayModeGemini {
+				usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
+			} else {
+				err, usage = gemini.GeminiChatHandler(c, resp, info)
+			}
 		case RequestModeLlama:
 			err, usage = openai.OpenaiHandler(c, resp, info)
 		}

+ 4 - 0
relay/constant/relay_mode.go

@@ -43,6 +43,8 @@ const (
 	RelayModeResponses
 
 	RelayModeRealtime
+
+	RelayModeGemini
 )
 
 func Path2RelayMode(path string) int {
@@ -75,6 +77,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeRerank
 	} else if strings.HasPrefix(path, "/v1/realtime") {
 		relayMode = RelayModeRealtime
+	} else if strings.HasPrefix(path, "/v1beta/models") {
+		relayMode = RelayModeGemini
 	}
 	return relayMode
 }
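
A quick sketch of how Path2RelayMode now classifies Gemini-native paths, assuming the one-api module is importable:

package main

import (
	"fmt"

	relayconstant "one-api/relay/constant"
)

func main() {
	fmt.Println(relayconstant.Path2RelayMode("/v1beta/models/gemini-2.0-flash:generateContent") == relayconstant.RelayModeGemini) // true
	fmt.Println(relayconstant.Path2RelayMode("/v1/chat/completions") == relayconstant.RelayModeGemini)                            // false
}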

+ 157 - 0
relay/relay-gemini.go

@@ -0,0 +1,157 @@
+package relay
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"one-api/common"
+	"one-api/dto"
+	"one-api/relay/channel/gemini"
+	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
+	"one-api/service"
+	"one-api/setting"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
+	request := &gemini.GeminiChatRequest{}
+	err := common.UnmarshalBodyReusable(c, request)
+	if err != nil {
+		return nil, err
+	}
+	if len(request.Contents) == 0 {
+		return nil, errors.New("contents is required")
+	}
+	return request, nil
+}
+
+// Stream mode, e.g.:
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if c.Query("alt") == "sse" {
+		relayInfo.IsStream = true
+	}
+
+	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+	// 	relayInfo.IsStream = true
+	// }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
+	var inputTexts []string
+	for _, content := range textRequest.Contents {
+		for _, part := range content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+	if len(inputTexts) == 0 {
+		return nil, nil
+	}
+
+	sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
+	return sensitiveWords, err
+}
+
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
+	// Count input tokens
+	var inputTexts []string
+	for _, content := range req.Contents {
+		for _, part := range content.Parts {
+			if part.Text != "" {
+				inputTexts = append(inputTexts, part.Text)
+			}
+		}
+	}
+
+	inputText := strings.Join(inputTexts, "\n")
+	inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
+	info.PromptTokens = inputTokens
+	return inputTokens, err
+}
+
+func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
+	req, err := getAndValidateGeminiRequest(c)
+	if err != nil {
+		common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
+		return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
+	}
+
+	relayInfo := relaycommon.GenRelayInfo(c)
+
+	// Detect Gemini streaming mode
+	checkGeminiStreamMode(c, relayInfo)
+
+	if setting.ShouldCheckPromptSensitive() {
+		sensitiveWords, err := checkGeminiInputSensitive(req)
+		if err != nil {
+			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
+			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
+		}
+	}
+
+	// apply model mapping
+	err = helper.ModelMappedHelper(c, relayInfo)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
+	}
+
+	if value, exists := c.Get("prompt_tokens"); exists {
+		promptTokens := value.(int)
+		relayInfo.SetPromptTokens(promptTokens)
+	} else {
+		promptTokens, err := getGeminiInputTokens(req, relayInfo)
+		if err != nil {
+			return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
+		}
+		c.Set("prompt_tokens", promptTokens)
+	}
+
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
+	}
+
+	// pre-consume quota
+	preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+	defer func() {
+		if openaiErr != nil {
+			returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
+		}
+	}()
+
+	adaptor := GetAdaptor(relayInfo.ApiType)
+	if adaptor == nil {
+		return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
+	}
+
+	adaptor.Init(relayInfo)
+
+	requestBody, err := json.Marshal(req)
+	if err != nil {
+		return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
+	}
+
+	resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
+	if err != nil {
+		common.LogError(c, "Do gemini request failed: "+err.Error())
+		return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
+	}
+
+	usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
+	if openaiErr != nil {
+		return openaiErr
+	}
+
+	postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
+	return nil
+}
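
Because checkGeminiStreamMode only enables streaming when the query carries alt=sse, a streaming caller must include that parameter. A hedged consumer sketch (hypothetical URL and token):

package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	url := "http://localhost:3000/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=sk-your-token"
	resp, err := http.Post(url, "application/json",
		strings.NewReader(`{"contents":[{"parts":[{"text":"Hello"}]}]}`))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if line := scanner.Text(); strings.HasPrefix(line, "data: ") {
			fmt.Println(line) // each event is a Gemini-native JSON chunk
		}
	}
}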

+ 8 - 0
router/relay-router.go

@@ -79,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) {
 		relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
 	}
 
+	relayGeminiRouter := router.Group("/v1beta")
+	relayGeminiRouter.Use(middleware.TokenAuth())
+	relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
+	relayGeminiRouter.Use(middleware.Distribute())
+	{
+		// Gemini API path format: /v1beta/models/{model_name}:{action}
+		relayGeminiRouter.POST("/models/*path", controller.Relay)
+	}
 }
 
 func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {