Преглед изворни кода

feat: 优化Gemini模型版本获取逻辑

CalciumIon пре 1 година
родитељ
комит
e504665f68
4 измењених фајлова са 35 додато и 26 уклоњено
  1. 1 1
      README.md
  2. 29 0
      constant/env.go
  3. 2 0
      main.go
  4. 3 25
      relay/channel/gemini/adaptor.go

+ 1 - 1
README.md

@@ -64,7 +64,7 @@
 - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
-- `GEMINI_MODEL_API`:Gemini模型指定版本(v1/v1beta),如果配置会覆盖默认配置需要完整给出全部v1beta模型,使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置:gemini-1.5-pro-latest,gemini-1.5-pro-001,gemini-1.5-pro,gemini-1.5-pro-exp-0801,gemini-1.5-flash-latest,gemini-1.5-flash-001,gemini-1.5-flash,gemini-ultra模型为v1beta,其他为v1。
+- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
 
 ## 部署
 ### 部署要求

+ 29 - 0
constant/env.go

@@ -1,7 +1,10 @@
 package constant
 
 import (
+	"fmt"
 	"one-api/common"
+	"os"
+	"strings"
 )
 
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
@@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
 var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
 
 var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
+
+var GeminiModelMap = map[string]string{
+	"gemini-1.5-pro-latest":   "v1beta",
+	"gemini-1.5-pro-001":      "v1beta",
+	"gemini-1.5-pro":          "v1beta",
+	"gemini-1.5-pro-exp-0801": "v1beta",
+	"gemini-1.5-flash-latest": "v1beta",
+	"gemini-1.5-flash-001":    "v1beta",
+	"gemini-1.5-flash":        "v1beta",
+	"gemini-ultra":            "v1beta",
+}
+
+func InitEnv() {
+	modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
+	if modelVersionMapStr == "" {
+		return
+	}
+	for _, pair := range strings.Split(modelVersionMapStr, ",") {
+		parts := strings.Split(pair, ":")
+		if len(parts) == 2 {
+			GeminiModelMap[parts[0]] = parts[1]
+		} else {
+			common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
+		}
+	}
+}

+ 2 - 0
main.go

@@ -55,6 +55,8 @@ func main() {
 		common.FatalLog("failed to initialize Redis: " + err.Error())
 	}
 
+	// Initialize constants
+	constant.InitEnv()
 	// Initialize options
 	model.InitOptionMap()
 	if common.RedisEnabled {

+ 3 - 25
relay/channel/gemini/adaptor.go

@@ -6,15 +6,13 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
-	"os"
+	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
-	"strings"
 	relaycommon "one-api/relay/common"
 )
 
 type Adaptor struct {
-	modelVersionMap map[string]string
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -28,32 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-	modelVersionMapStr := os.Getenv("GEMINI_MODEL_API")
-	if modelVersionMapStr == "" {
-		a.modelVersionMap = map[string]string{ 
-			"gemini-1.5-pro-latest":   "v1beta",
-			"gemini-1.5-pro-001":      "v1beta",
-			"gemini-1.5-pro":          "v1beta",
-			"gemini-1.5-pro-exp-0801": "v1beta",
-			"gemini-1.5-flash-latest": "v1beta",
-			"gemini-1.5-flash-001":    "v1beta",
-			"gemini-1.5-flash":        "v1beta",
-			"gemini-ultra":            "v1beta",
-		}
-		return
-	}
-	a.modelVersionMap = make(map[string]string) 
-	for _, pair := range strings.Split(modelVersionMapStr, ",") {
-		parts := strings.Split(pair, ":")
-		if len(parts) == 2 {
-			a.modelVersionMap[parts[0]] = parts[1] 
-		}
-	}
+
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
-	version, beta := a.modelVersionMap[info.UpstreamModelName]
+	version, beta := constant.GeminiModelMap[info.UpstreamModelName]
 	if !beta {
 		if info.ApiVersion != "" {
 			version = info.ApiVersion