|
|
@@ -7,20 +7,21 @@ import (
|
|
|
"one-api/common"
|
|
|
"one-api/dto"
|
|
|
relaycommon "one-api/relay/common"
|
|
|
+ "one-api/relay/helper"
|
|
|
"one-api/service"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
)
|
|
|
|
|
|
-func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
|
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
|
|
// 读取响应体
|
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
err = resp.Body.Close()
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
|
|
|
if common.DebugEnabled {
|
|
|
@@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|
|
}
|
|
|
|
|
|
// 解析为 Gemini 原生响应格式
|
|
|
- var geminiResponse dto.GeminiTextGenerationResponse
|
|
|
+ var geminiResponse GeminiChatResponse
|
|
|
err = common.DecodeJson(responseBody, &geminiResponse)
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
|
|
|
// 检查是否有候选响应
|
|
|
if len(geminiResponse.Candidates) == 0 {
|
|
|
- return &dto.OpenAIErrorWithStatusCode{
|
|
|
+ return nil, &dto.OpenAIErrorWithStatusCode{
|
|
|
Error: dto.OpenAIError{
|
|
|
Message: "No candidates returned",
|
|
|
Type: "server_error",
|
|
|
@@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|
|
Code: 500,
|
|
|
},
|
|
|
StatusCode: resp.StatusCode,
|
|
|
- }, nil
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// 计算使用量(基于 UsageMetadata)
|
|
|
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|
|
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
|
|
}
|
|
|
|
|
|
- // 设置模型版本
|
|
|
- if geminiResponse.ModelVersion == "" {
|
|
|
- geminiResponse.ModelVersion = info.UpstreamModelName
|
|
|
- }
|
|
|
-
|
|
|
// 直接返回 Gemini 原生格式的 JSON 响应
|
|
|
jsonResponse, err := json.Marshal(geminiResponse)
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
|
+ return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
|
|
|
// 设置响应头并写入响应
|
|
|
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
|
if err != nil {
|
|
|
- return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
|
|
|
+ return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
|
|
}
|
|
|
|
|
|
- return nil, &usage
|
|
|
+ return &usage, nil
|
|
|
+}
|
|
|
+
|
|
|
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
|
|
+ var usage = &dto.Usage{}
|
|
|
+ var imageCount int
|
|
|
+
|
|
|
+ helper.SetEventStreamHeaders(c)
|
|
|
+
|
|
|
+ helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
|
|
+ var geminiResponse GeminiChatResponse
|
|
|
+ err := common.DecodeJsonStr(data, &geminiResponse)
|
|
|
+ if err != nil {
|
|
|
+ common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ // 统计图片数量
|
|
|
+ for _, candidate := range geminiResponse.Candidates {
|
|
|
+ for _, part := range candidate.Content.Parts {
|
|
|
+ if part.InlineData != nil && part.InlineData.MimeType != "" {
|
|
|
+ imageCount++
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 更新使用量统计
|
|
|
+ if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
|
|
+ usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
|
|
+ usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
|
|
+ usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
|
|
+ }
|
|
|
+
|
|
|
+ // 直接发送 GeminiChatResponse 响应
|
|
|
+ err = helper.ObjectData(c, geminiResponse)
|
|
|
+ if err != nil {
|
|
|
+ common.LogError(c, err.Error())
|
|
|
+ }
|
|
|
+
|
|
|
+ return true
|
|
|
+ })
|
|
|
+
|
|
|
+ if imageCount != 0 {
|
|
|
+ if usage.CompletionTokens == 0 {
|
|
|
+ usage.CompletionTokens = imageCount * 258
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // 计算最终使用量
|
|
|
+ usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
|
|
+ usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
|
|
+
|
|
|
+ // 结束流式响应
|
|
|
+ helper.Done(c)
|
|
|
+
|
|
|
+ return usage, nil
|
|
|
}
|