Browse Source

修改Gemini版本获取逻辑

使用GEMINI_MODEL_API环境变量覆盖默认版本映射,使用","分隔不同模型和版本
-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta,gemini-1.5-pro:v1beta,gemini-1.5-flash-latest:v1beta,gemini-1.5-flash-001:v1beta,gemini-1.5-flash:v1beta,gemini-ultra:v1beta,gemini-1.5-pro-exp-0801:v1beta"
HowieWu 1 year ago
parent
commit
22a98c5879
1 changed files with 25 additions and 8 deletions
  1. 25 8
      relay/channel/gemini/adaptor.go

+ 25 - 8
relay/channel/gemini/adaptor.go

@@ -6,12 +6,15 @@ import (
 	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
+	"os"
 	"one-api/dto"
 	"one-api/relay/channel"
+	"strings"
 	relaycommon "one-api/relay/common"
 )
 
 type Adaptor struct {
+	modelVersionMap map[string]string
 }
 
 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -25,18 +28,32 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
 }
 
 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
-}
-
-// 定义一个映射,存储模型名称和对应的版本
-var modelVersionMap = map[string]string{
-	"gemini-1.5-pro-latest":   "v1beta",
-	"gemini-1.5-flash-latest": "v1beta",
-	"gemini-ultra":            "v1beta",
+	modelVersionMapStr := os.Getenv("GEMINI_MODEL_API")
+	if modelVersionMapStr == "" {
+		a.modelVersionMap = map[string]string{ 
+			"gemini-1.5-pro-latest":   "v1beta",
+			"gemini-1.5-pro-001":      "v1beta",
+			"gemini-1.5-pro":          "v1beta",
+			"gemini-1.5-pro-exp-0801": "v1beta",
+			"gemini-1.5-flash-latest": "v1beta",
+			"gemini-1.5-flash-001":    "v1beta",
+			"gemini-1.5-flash":        "v1beta",
+			"gemini-ultra":            "v1beta",
+		}
+		return
+	}
+	a.modelVersionMap = make(map[string]string) 
+	for _, pair := range strings.Split(modelVersionMapStr, ",") {
+		parts := strings.Split(pair, ":")
+		if len(parts) == 2 {
+			a.modelVersionMap[parts[0]] = parts[1] 
+		}
+	}
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
-	version, beta := modelVersionMap[info.UpstreamModelName]
+	version, beta := a.modelVersionMap[info.UpstreamModelName]
 	if !beta {
 		if info.ApiVersion != "" {
 			version = info.ApiVersion