Parcourir la source

feat: Add Gemini version settings configuration support (close #568)

[email protected] il y a 10 mois
Parent
commit
bf80d71ddf

+ 0 - 4
.env.example

@@ -50,10 +50,6 @@
 # CHANNEL_TEST_FREQUENCY=10
 # 生成默认token
 # GENERATE_DEFAULT_TOKEN=false
-# Gemini 安全设置
-# GEMINI_SAFETY_SETTING=BLOCK_NONE
-# Gemini版本设置
-# GEMINI_MODEL_MAP=gemini-1.0-pro:v1
 # Cohere 安全设置
 # COHERE_SAFETY_SETTING=NONE
 # 是否统计图片token

+ 4 - 1
README.md

@@ -94,7 +94,6 @@
 - `GET_MEDIA_TOKEN`:是否统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
-- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用"模型:版本"指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置(v1beta)
 - `COHERE_SAFETY_SETTING`:Cohere模型[安全设置](https://docs.cohere.com/docs/safety-modes#overview),可选值为 `NONE`, `CONTEXTUAL`, `STRICT`,默认为 `NONE`。
 - `GEMINI_VISION_MAX_IMAGE_NUM`:Gemini模型最大图片数量,默认为 `16`,设置为 `-1` 则不限制。
 - `MAX_FILE_DOWNLOAD_MB`: 最大文件下载大小,单位 MB,默认为 `20`。
@@ -103,6 +102,10 @@
 - `NOTIFICATION_LIMIT_DURATION_MINUTE`:通知限制的持续时间(分钟),默认为 `10`。
 - `NOTIFY_LIMIT_COUNT`:用户通知在指定持续时间内的最大数量,默认为 `2`。
 
+## 已废弃的环境变量
+- ~~`GEMINI_MODEL_MAP`(已废弃)~~:改为到`设置-模型相关设置`中设置
+- ~~`GEMINI_SAFETY_SETTING`(已废弃)~~:改为到`设置-模型相关设置`中设置
+
 ## 部署
 
 > [!TIP]

+ 15 - 18
constant/env.go

@@ -1,10 +1,7 @@
 package constant
 
 import (
-	"fmt"
 	"one-api/common"
-	"os"
-	"strings"
 )
 
 var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 60)
@@ -23,9 +20,9 @@ var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
 
 var AzureDefaultAPIVersion = common.GetEnvOrDefaultString("AZURE_DEFAULT_API_VERSION", "2024-12-01-preview")
 
-var GeminiModelMap = map[string]string{
-	"gemini-1.0-pro": "v1",
-}
+//var GeminiModelMap = map[string]string{
+//	"gemini-1.0-pro": "v1",
+//}
 
 var GeminiVisionMaxImageNum = common.GetEnvOrDefault("GEMINI_VISION_MAX_IMAGE_NUM", 16)
 
@@ -33,18 +30,18 @@ var NotifyLimitCount = common.GetEnvOrDefault("NOTIFY_LIMIT_COUNT", 2)
 var NotificationLimitDurationMinute = common.GetEnvOrDefault("NOTIFICATION_LIMIT_DURATION_MINUTE", 10)
 
 func InitEnv() {
-	modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
-	if modelVersionMapStr == "" {
-		return
-	}
-	for _, pair := range strings.Split(modelVersionMapStr, ",") {
-		parts := strings.Split(pair, ":")
-		if len(parts) == 2 {
-			GeminiModelMap[parts[0]] = parts[1]
-		} else {
-			common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
-		}
-	}
+	//modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
+	//if modelVersionMapStr == "" {
+	//	return
+	//}
+	//for _, pair := range strings.Split(modelVersionMapStr, ",") {
+	//	parts := strings.Split(pair, ":")
+	//	if len(parts) == 2 {
+	//		GeminiModelMap[parts[0]] = parts[1]
+	//	} else {
+	//		common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
+	//	}
+	//}
 }
 
 // GenerateDefaultToken 是否生成初始令牌,默认关闭。

+ 6 - 2
model/option.go

@@ -3,6 +3,7 @@ package model
 import (
 	"one-api/common"
 	"one-api/setting"
+	"one-api/setting/model_setting"
 	"strconv"
 	"strings"
 	"time"
@@ -115,7 +116,8 @@ func InitOptionMap() {
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
 	common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
-	common.OptionMap["GeminiSafetySettings"] = setting.GeminiSafetySettingsJsonString()
+	common.OptionMap["GeminiSafetySettings"] = model_setting.GeminiSafetySettingsJsonString()
+	common.OptionMap["GeminiVersionSettings"] = model_setting.GeminiVersionSettingsJsonString()
 
 	common.OptionMapRWMutex.Unlock()
 	loadOptionsFromDatabase()
@@ -353,7 +355,9 @@ func updateOptionMap(key string, value string) (err error) {
 	case "AutomaticDisableKeywords":
 		setting.AutomaticDisableKeywordsFromString(value)
 	case "GeminiSafetySettings":
-		setting.GeminiSafetySettingFromJsonString(value)
+		model_setting.GeminiSafetySettingFromJsonString(value)
+	case "GeminiVersionSettings":
+		model_setting.GeminiVersionSettingFromJsonString(value)
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 	}

+ 2 - 10
relay/channel/gemini/adaptor.go

@@ -7,11 +7,11 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"one-api/setting/model_setting"
 
 	"strings"
 
@@ -64,15 +64,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }
 
 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1beta"
-	version, beta := constant.GeminiModelMap[info.UpstreamModelName]
-	if !beta {
-		if info.ApiVersion != "" {
-			version = info.ApiVersion
-		} else {
-			version = "v1beta"
-		}
-	}
+	version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
 
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
 		return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil

+ 2 - 2
relay/channel/gemini/relay-gemini.go

@@ -11,7 +11,7 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
-	"one-api/setting"
+	"one-api/setting/model_setting"
 	"strings"
 	"unicode/utf8"
 
@@ -36,7 +36,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 	for _, category := range SafetySettingList {
 		safetySettings = append(safetySettings, GeminiChatSafetySettings{
 			Category:  category,
-			Threshold: setting.GetGeminiSafetySetting(category),
+			Threshold: model_setting.GetGeminiSafetySetting(category),
 		})
 	}
 	geminiRequest.SafetySettings = safetySettings

+ 39 - 1
setting/model_setting.go → setting/model_setting/gemini.go

@@ -1,4 +1,4 @@
-package setting
+package model_setting
 
 import (
 	"encoding/json"
@@ -43,3 +43,41 @@ func GeminiSafetySettingsJsonString() string {
 	}
 	return string(jsonString)
 }
+
+var geminiVersionSettings = map[string]string{
+	"default":        "v1beta",
+	"gemini-1.0-pro": "v1",
+}
+
+func GetGeminiVersionSetting(key string) string {
+	if value, ok := geminiVersionSettings[key]; ok {
+		return value
+	}
+	return geminiVersionSettings["default"]
+}
+
+func GeminiVersionSettingFromJsonString(jsonString string) {
+	geminiVersionSettings = map[string]string{}
+	err := json.Unmarshal([]byte(jsonString), &geminiVersionSettings)
+	if err != nil {
+		geminiVersionSettings = map[string]string{
+			"default": "v1beta",
+		}
+	}
+	// check must have default
+	if _, ok := geminiVersionSettings["default"]; !ok {
+		geminiVersionSettings["default"] = "v1beta"
+	}
+}
+
+func GeminiVersionSettingsJsonString() string {
+	// check must have default
+	if _, ok := geminiVersionSettings["default"]; !ok {
+		geminiVersionSettings["default"] = "v1beta"
+	}
+	jsonString, err := json.Marshal(geminiVersionSettings)
+	if err != nil {
+		return "{}"
+	}
+	return string(jsonString)
+}

+ 3 - 1
web/src/components/ModelSetting.js

@@ -22,6 +22,7 @@ const ModelSetting = () => {
   const { t } = useTranslation();
   let [inputs, setInputs] = useState({
     GeminiSafetySettings: '',
+    GeminiVersionSettings: '',
   });
 
   let [loading, setLoading] = useState(false);
@@ -33,7 +34,8 @@ const ModelSetting = () => {
       let newInputs = {};
       data.forEach((item) => {
         if (
-          item.key === 'GeminiSafetySettings'
+          item.key === 'GeminiSafetySettings' ||
+          item.key === 'GeminiVersionSettings'
         ) {
           item.value = JSON.stringify(JSON.parse(item.value), null, 2);
         }

+ 27 - 0
web/src/pages/Setting/Model/SettingGeminiModel.js

@@ -14,12 +14,18 @@ const GEMINI_SETTING_EXAMPLE = {
   'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE',
 };
 
+const GEMINI_VERSION_EXAMPLE = {
+  'default': 'v1beta',
+};
+
+
 export default function SettingGeminiModel(props) {
   const { t } = useTranslation();
 
   const [loading, setLoading] = useState(false);
   const [inputs, setInputs] = useState({
     GeminiSafetySettings: '',
+    GeminiVersionSettings: '',
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -99,6 +105,27 @@ export default function SettingGeminiModel(props) {
                 />
               </Col>
             </Row>
+            <Row>
+              <Col span={16}>
+                <Form.TextArea
+                  label={t('Gemini版本设置')}
+                  placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_VERSION_EXAMPLE, null, 2)}
+                  field={'GeminiVersionSettings'}
+                  extraText={t('default为默认设置,可单独设置每个模型的版本')}
+                  autosize={{ minRows: 6, maxRows: 12 }}
+                  trigger='blur'
+                  stopValidateWithError
+                  rules={[
+                    {
+                      validator: (rule, value) => verifyJSON(value),
+                      message: t('不是合法的 JSON 字符串')
+                    }
+                  ]}
+                  onChange={(value) => setInputs({ ...inputs, GeminiVersionSettings: value })}
+                />
+              </Col>
+            </Row>
+
             <Row>
               <Button size='default' onClick={onSubmit}>
                 {t('保存')}