
gemini stream

creamlike1024, 7 months ago
commit d90e4bef63
4 changed files with 93 additions and 92 deletions
  1. dto/gemini.go (+0, -69)
  2. relay/channel/gemini/adaptor.go (+5, -2)
  3. relay/channel/gemini/relay-gemini-native.go (+66, -15)
  4. relay/relay-gemini.go (+22, -6)

dto/gemini.go  (+0, -69)

@@ -1,69 +0,0 @@
-package dto
-
-import "encoding/json"
-
-type GeminiPart struct {
-	Text string `json:"text"`
-}
-
-type GeminiContent struct {
-	Parts []GeminiPart `json:"parts"`
-	Role  string       `json:"role"`
-}
-
-type GeminiCandidate struct {
-	Content      GeminiContent `json:"content"`
-	FinishReason string        `json:"finishReason"`
-	AvgLogprobs  float64       `json:"avgLogprobs"`
-}
-
-type GeminiTokenDetails struct {
-	Modality   string `json:"modality"`
-	TokenCount int    `json:"tokenCount"`
-}
-
-type GeminiUsageMetadata struct {
-	PromptTokenCount        int                  `json:"promptTokenCount"`
-	CandidatesTokenCount    int                  `json:"candidatesTokenCount"`
-	TotalTokenCount         int                  `json:"totalTokenCount"`
-	PromptTokensDetails     []GeminiTokenDetails `json:"promptTokensDetails"`
-	CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"`
-}
-
-type GeminiTextGenerationResponse struct {
-	Candidates    []GeminiCandidate   `json:"candidates"`
-	UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
-	ModelVersion  string              `json:"modelVersion"`
-	ResponseID    string              `json:"responseId"`
-}
-
-type GeminiGenerationConfig struct {
-	StopSequences              []string         `json:"stopSequences,omitempty"`
-	ResponseMimeType           string           `json:"responseMimeType,omitempty"`
-	ResponseSchema             *json.RawMessage `json:"responseSchema,omitempty"`
-	ResponseModalities         *json.RawMessage `json:"responseModalities,omitempty"`
-	CandidateCount             int              `json:"candidateCount,omitempty"`
-	MaxOutputTokens            int              `json:"maxOutputTokens,omitempty"`
-	Temperature                float64          `json:"temperature,omitempty"`
-	TopP                       float64          `json:"topP,omitempty"`
-	TopK                       int              `json:"topK,omitempty"`
-	Seed                       int              `json:"seed,omitempty"`
-	PresencePenalty            float64          `json:"presencePenalty,omitempty"`
-	FrequencyPenalty           float64          `json:"frequencyPenalty,omitempty"`
-	ResponseLogprobs           bool             `json:"responseLogprobs,omitempty"`
-	LogProbs                   int              `json:"logProbs,omitempty"`
-	EnableEnhancedCivicAnswers bool             `json:"enableEnhancedCivicAnswers,omitempty"`
-	SpeechConfig               *json.RawMessage `json:"speechConfig,omitempty"`
-	ThinkingConfig             *json.RawMessage `json:"thinkingConfig,omitempty"`
-	MediaResolution            *json.RawMessage `json:"mediaResolution,omitempty"`
-}
-
-type GeminiTextGenerationRequest struct {
-	Contents          []GeminiContent        `json:"contents"`
-	Tools             *json.RawMessage       `json:"tools,omitempty"`
-	ToolConfig        *json.RawMessage       `json:"toolConfig,omitempty"`
-	SafetySettings    *json.RawMessage       `json:"safetySettings,omitempty"`
-	SystemInstruction *json.RawMessage       `json:"systemInstruction,omitempty"`
-	GenerationConfig  GeminiGenerationConfig `json:"generationConfig,omitempty"`
-	CachedContent     *json.RawMessage       `json:"cachedContent,omitempty"`
-}
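
The deleted dto structs are superseded by the GeminiChatRequest / GeminiChatResponse types that already live in relay/channel/gemini and are used by the handlers below. Their definitions are not part of this diff; the sketch that follows only infers the minimal fields those handlers touch (part.Text, InlineData.MimeType, UsageMetadata counters, GenerationConfig.MaxOutputTokens), and everything beyond the two top-level type names is an assumption.

// Inferred sketch only - not the actual definitions in relay/channel/gemini.
// Nested shapes are written as anonymous structs to avoid guessing type names.
package gemini

type GeminiChatRequest struct {
	Contents []struct {
		Role  string `json:"role,omitempty"`
		Parts []struct {
			Text string `json:"text,omitempty"`
		} `json:"parts"`
	} `json:"contents"`
	GenerationConfig struct {
		// assumed unsigned, since the call site below wraps it in int(...)
		MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
	} `json:"generationConfig"`
}

type GeminiChatResponse struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text       string `json:"text,omitempty"`
				InlineData *struct {
					MimeType string `json:"mimeType"`
					Data     string `json:"data"`
				} `json:"inlineData,omitempty"`
			} `json:"parts"`
		} `json:"content"`
		FinishReason string `json:"finishReason"`
	} `json:"candidates"`
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
		TotalTokenCount      int `json:"totalTokenCount"`
	} `json:"usageMetadata"`
}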

relay/channel/gemini/adaptor.go  (+5, -2)

@@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeGemini {
-		err, usage = GeminiTextGenerationHandler(c, resp, info)
-		return usage, err
+		if info.IsStream {
+			return GeminiTextGenerationStreamHandler(c, resp, info)
+		} else {
+			return GeminiTextGenerationHandler(c, resp, info)
+		}
 	}
 
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {

relay/channel/gemini/relay-gemini-native.go  (+66, -15)

@@ -7,20 +7,21 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 
 	"github.com/gin-gonic/gin"
 )
 
-func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
 	// Read the response body
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 
 	if common.DebugEnabled {
@@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	// Parse into the Gemini native response format
-	var geminiResponse dto.GeminiTextGenerationResponse
+	var geminiResponse GeminiChatResponse
 	err = common.DecodeJson(responseBody, &geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 
 	// Check whether any candidates were returned
 	if len(geminiResponse.Candidates) == 0 {
-		return &dto.OpenAIErrorWithStatusCode{
+		return nil, &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
@@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 				Code:    500,
 			},
 			StatusCode: resp.StatusCode,
-		}, nil
+		}
 	}
 
 	// Compute usage (based on UsageMetadata)
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 
-	// Set the model version
-	if geminiResponse.ModelVersion == "" {
-		geminiResponse.ModelVersion = info.UpstreamModelName
-	}
-
 	// Return the Gemini-native JSON response directly
 	jsonResponse, err := json.Marshal(geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}
 
 	// Set response headers and write the response
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = c.Writer.Write(jsonResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 	}
 
-	return nil, &usage
+	return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+	// Count generated images
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+			}
+		}
+
+	// Update usage statistics
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+		}
+
+	// Send the GeminiChatResponse chunk directly
+		err = helper.ObjectData(c, geminiResponse)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// Compute the final usage
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+	// End the streaming response
+	helper.Done(c)
+
+	return usage, nil
 }
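
For reference, each SSE event consumed by GeminiTextGenerationStreamHandler carries one JSON chunk in Gemini's streamGenerateContent format. The self-contained sketch below decodes a sample payload into a local stand-in struct (the real handler uses GeminiChatResponse) and reproduces the final completion = total - prompt accounting; the payload values are made up for illustration.

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for the fields the stream handler reads from each chunk.
type chunk struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
		} `json:"content"`
		FinishReason string `json:"finishReason"`
	} `json:"candidates"`
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
		TotalTokenCount      int `json:"totalTokenCount"`
	} `json:"usageMetadata"`
}

func main() {
	// Payload of a single "data:" line from streamGenerateContent?alt=sse (example values).
	data := `{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"},"finishReason":"STOP"}],` +
		`"usageMetadata":{"promptTokenCount":4,"candidatesTokenCount":12,"totalTokenCount":16}}`

	var c chunk
	if err := json.Unmarshal([]byte(data), &c); err != nil {
		panic(err)
	}
	// Same accounting as the handler: completion tokens end up as total - prompt.
	completion := c.UsageMetadata.TotalTokenCount - c.UsageMetadata.PromptTokenCount
	fmt.Println(c.Candidates[0].Content.Parts[0].Text, completion)
}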

relay/relay-gemini.go  (+22, -6)

@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/channel/gemini"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
@@ -17,8 +18,8 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
-	request := &dto.GeminiTextGenerationRequest{}
+func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
+	request := &gemini.GeminiChatRequest{}
 	err := common.UnmarshalBodyReusable(c, request)
 	if err != nil {
 		return nil, err
@@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque
 	return request, nil
 }
 
-func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
+// Streaming mode:
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if c.Query("alt") == "sse" {
+		relayInfo.IsStream = true
+	}
+
+	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+	// 	relayInfo.IsStream = true
+	// }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
 	var inputTexts []string
 	for _, content := range textRequest.Contents {
 		for _, part := range content.Parts {
@@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf
 	return sensitiveWords, err
 }
 
-func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
 	// Count the input tokens
 	var inputTexts []string
 	for _, content := range req.Contents {
@@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	relayInfo := relaycommon.GenRelayInfo(c)
 
+	// Check Gemini streaming mode
+	checkGeminiStreamMode(c, relayInfo)
+
 	if setting.ShouldCheckPromptSensitive() {
-		sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
+		sensitiveWords, err := checkGeminiInputSensitive(req)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
 			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
@@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		c.Set("prompt_tokens", promptTokens)
 	}
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
 	}
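
End to end, streaming is opted into purely through the alt=sse query parameter that checkGeminiStreamMode inspects. The client sketch below is illustrative only; the base URL, port, and key parameter are assumptions about a particular deployment, not taken from this diff - only the path shape and alt=sse matter.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	// Hypothetical local deployment of the relay.
	url := "http://localhost:3000/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=sk-xxx"
	body := []byte(`{"contents":[{"role":"user","parts":[{"text":"Hi"}]}]}`)

	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The relay answers with SSE: one "data: {...}" line per GeminiChatResponse chunk.
	sc := bufio.NewScanner(resp.Body)
	for sc.Scan() {
		if line := sc.Text(); strings.HasPrefix(line, "data: ") {
			fmt.Println(strings.TrimPrefix(line, "data: "))
		}
	}
}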