Przeglądaj źródła

feat: Add Gemini safety settings configuration support (close #703)

[email protected] 10 miesięcy temu
rodzic
commit
e19b244e73

+ 3 - 0
model/option.go

@@ -115,6 +115,7 @@ func InitOptionMap() {
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
 	common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
+	common.OptionMap["GeminiSafetySettings"] = setting.GeminiSafetySettingsJsonString()
 
 	common.OptionMapRWMutex.Unlock()
 	loadOptionsFromDatabase()
@@ -351,6 +352,8 @@ func updateOptionMap(key string, value string) (err error) {
 		setting.SensitiveWordsFromString(value)
 	case "AutomaticDisableKeywords":
 		setting.AutomaticDisableKeywordsFromString(value)
+	case "GeminiSafetySettings":
+		setting.GeminiSafetySettingFromJsonString(value)
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 	}

+ 8 - 0
relay/channel/gemini/constant.go

@@ -20,4 +20,12 @@ var ModelList = []string{
 	"imagen-3.0-generate-002",
 }
 
+var SafetySettingList = []string{
+	"HARM_CATEGORY_HARASSMENT",
+	"HARM_CATEGORY_VIOLENCE",
+	"HARM_CATEGORY_SEXUALLY_EXPLICIT",
+	"HARM_CATEGORY_DANGEROUS_CONTENT",
+	"HARM_CATEGORY_CIVIC_INTEGRITY",
+}
+
 var ChannelName = "google gemini"

+ 11 - 22
relay/channel/gemini/relay-gemini.go

@@ -11,6 +11,7 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"one-api/setting"
 	"strings"
 	"unicode/utf8"
 
@@ -22,28 +23,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 
 	geminiRequest := GeminiChatRequest{
 		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
-		SafetySettings: []GeminiChatSafetySettings{
-			{
-				Category:  "HARM_CATEGORY_HARASSMENT",
-				Threshold: common.GeminiSafetySetting,
-			},
-			{
-				Category:  "HARM_CATEGORY_HATE_SPEECH",
-				Threshold: common.GeminiSafetySetting,
-			},
-			{
-				Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-				Threshold: common.GeminiSafetySetting,
-			},
-			{
-				Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
-				Threshold: common.GeminiSafetySetting,
-			},
-			{
-				Category:  "HARM_CATEGORY_CIVIC_INTEGRITY",
-				Threshold: common.GeminiSafetySetting,
-			},
-		},
+		//SafetySettings: []GeminiChatSafetySettings{},
 		GenerationConfig: GeminiChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
@@ -52,6 +32,15 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) (*GeminiChatReque
 		},
 	}
 
+	safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
+	for _, category := range SafetySettingList {
+		safetySettings = append(safetySettings, GeminiChatSafetySettings{
+			Category:  category,
+			Threshold: setting.GetGeminiSafetySetting(category),
+		})
+	}
+	geminiRequest.SafetySettings = safetySettings
+
 	// openaiContent.FuncToToolCalls()
 	if textRequest.Tools != nil {
 		functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))

+ 45 - 0
setting/model_setting.go

@@ -0,0 +1,45 @@
+package setting
+
+import (
+	"encoding/json"
+	"one-api/common"
+)
+
+var geminiSafetySettings = map[string]string{
+	"default":                       "OFF",
+	"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
+}
+
+func GetGeminiSafetySetting(key string) string {
+	if value, ok := geminiSafetySettings[key]; ok {
+		return value
+	}
+	return geminiSafetySettings["default"]
+}
+
+func GeminiSafetySettingFromJsonString(jsonString string) {
+	geminiSafetySettings = map[string]string{}
+	err := json.Unmarshal([]byte(jsonString), &geminiSafetySettings)
+	if err != nil {
+		geminiSafetySettings = map[string]string{
+			"default":                       "OFF",
+			"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
+		}
+	}
+	// check must have default
+	if _, ok := geminiSafetySettings["default"]; !ok {
+		geminiSafetySettings["default"] = common.GeminiSafetySetting
+	}
+}
+
+func GeminiSafetySettingsJsonString() string {
+	// check must have default
+	if _, ok := geminiSafetySettings["default"]; !ok {
+		geminiSafetySettings["default"] = common.GeminiSafetySetting
+	}
+	jsonString, err := json.Marshal(geminiSafetySettings)
+	if err != nil {
+		return "{}"
+	}
+	return string(jsonString)
+}

+ 82 - 0
web/src/components/ModelSetting.js

@@ -0,0 +1,82 @@
+import React, { useEffect, useState } from 'react';
+import { Card, Spin, Tabs } from '@douyinfe/semi-ui';
+import SettingsGeneral from '../pages/Setting/Operation/SettingsGeneral.js';
+import SettingsDrawing from '../pages/Setting/Operation/SettingsDrawing.js';
+import SettingsSensitiveWords from '../pages/Setting/Operation/SettingsSensitiveWords.js';
+import SettingsLog from '../pages/Setting/Operation/SettingsLog.js';
+import SettingsDataDashboard from '../pages/Setting/Operation/SettingsDataDashboard.js';
+import SettingsMonitoring from '../pages/Setting/Operation/SettingsMonitoring.js';
+import SettingsCreditLimit from '../pages/Setting/Operation/SettingsCreditLimit.js';
+import SettingsMagnification from '../pages/Setting/Operation/SettingsMagnification.js';
+import ModelSettingsVisualEditor from '../pages/Setting/Operation/ModelSettingsVisualEditor.js';
+import GroupRatioSettings from '../pages/Setting/Operation/GroupRatioSettings.js';
+import ModelRatioSettings from '../pages/Setting/Operation/ModelRatioSettings.js';
+
+
+import { API, showError, showSuccess } from '../helpers';
+import SettingsChats from '../pages/Setting/Operation/SettingsChats.js';
+import { useTranslation } from 'react-i18next';
+import SettingGeminiModel from '../pages/Setting/Model/SettingGeminiModel.js';
+
+const ModelSetting = () => {
+  const { t } = useTranslation();
+  let [inputs, setInputs] = useState({
+    GeminiSafetySettings: '',
+  });
+
+  let [loading, setLoading] = useState(false);
+
+  const getOptions = async () => {
+    const res = await API.get('/api/option/');
+    const { success, message, data } = res.data;
+    if (success) {
+      let newInputs = {};
+      data.forEach((item) => {
+        if (
+          item.key === 'GeminiSafetySettings'
+        ) {
+          item.value = JSON.stringify(JSON.parse(item.value), null, 2);
+        }
+        if (
+          item.key.endsWith('Enabled')
+        ) {
+          newInputs[item.key] = item.value === 'true' ? true : false;
+        } else {
+          newInputs[item.key] = item.value;
+        }
+      });
+
+      setInputs(newInputs);
+    } else {
+      showError(message);
+    }
+  };
+  async function onRefresh() {
+    try {
+      setLoading(true);
+      await getOptions();
+      // showSuccess('刷新成功');
+    } catch (error) {
+      showError('刷新失败');
+    } finally {
+      setLoading(false);
+    }
+  }
+
+  useEffect(() => {
+    onRefresh();
+  }, []);
+
+  return (
+    <>
+      <Spin spinning={loading} size='large'>
+        {/* Gemini */}
+        <Card style={{ marginTop: '10px' }}>
+          <SettingGeminiModel options={inputs} refresh={onRefresh} />
+        </Card>
+      </Spin>
+    </>
+  );
+};
+
+export default ModelSetting;

+ 112 - 0
web/src/pages/Setting/Model/SettingGeminiModel.js

@@ -0,0 +1,112 @@
+import React, { useEffect, useState, useRef } from 'react';
+import { Button, Col, Form, Row, Spin } from '@douyinfe/semi-ui';
+import {
+  compareObjects,
+  API,
+  showError,
+  showSuccess,
+  showWarning, verifyJSON
+} from '../../../helpers';
+import { useTranslation } from 'react-i18next';
+
+const GEMINI_SETTING_EXAMPLE = {
+  'default': 'OFF',
+  'HARM_CATEGORY_CIVIC_INTEGRITY': 'BLOCK_NONE',
+};
+
+export default function SettingGeminiModel(props) {
+  const { t } = useTranslation();
+
+  const [loading, setLoading] = useState(false);
+  const [inputs, setInputs] = useState({
+    GeminiSafetySettings: '',
+  });
+  const refForm = useRef();
+  const [inputsRow, setInputsRow] = useState(inputs);
+
+  function onSubmit() {
+    const updateArray = compareObjects(inputs, inputsRow);
+    if (!updateArray.length) return showWarning(t('你似乎并没有修改什么'));
+    const requestQueue = updateArray.map((item) => {
+      let value = '';
+      if (typeof inputs[item.key] === 'boolean') {
+        value = String(inputs[item.key]);
+      } else {
+        value = inputs[item.key];
+      }
+      return API.put('/api/option/', {
+        key: item.key,
+        value,
+      });
+    });
+    setLoading(true);
+    Promise.all(requestQueue)
+      .then((res) => {
+        if (requestQueue.length === 1) {
+          if (res.includes(undefined)) return;
+        } else if (requestQueue.length > 1) {
+          if (res.includes(undefined)) return showError(t('部分保存失败,请重试'));
+        }
+        showSuccess(t('保存成功'));
+        props.refresh();
+      })
+      .catch(() => {
+        showError(t('保存失败,请重试'));
+      })
+      .finally(() => {
+        setLoading(false);
+      });
+  }
+
+  useEffect(() => {
+    const currentInputs = {};
+    for (let key in props.options) {
+      if (Object.keys(inputs).includes(key)) {
+        currentInputs[key] = props.options[key];
+      }
+    }
+    setInputs(currentInputs);
+    setInputsRow(structuredClone(currentInputs));
+    refForm.current.setValues(currentInputs);
+  }, [props.options]);
+
+  return (
+    <>
+      <Spin spinning={loading}>
+        <Form
+          values={inputs}
+          getFormApi={(formAPI) => (refForm.current = formAPI)}
+          style={{ marginBottom: 15 }}
+        >
+          <Form.Section text={t('Gemini设置')}>
+            <Row>
+              <Col span={16}>
+                <Form.TextArea
+                  label={t('Gemini安全设置')}
+                  placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_SETTING_EXAMPLE, null, 2)}
+                  field={'GeminiSafetySettings'}
+                  extraText={t('default为默认设置,可单独设置每个分类的安全等级')}
+                  autosize={{ minRows: 6, maxRows: 12 }}
+                  trigger='blur'
+                  stopValidateWithError
+                  rules={[
+                    {
+                      validator: (rule, value) => verifyJSON(value),
+                      message: t('不是合法的 JSON 字符串')
+                    }
+                  ]}
+                  onChange={(value) => setInputs({ ...inputs, GeminiSafetySettings: value })}
+                />
+              </Col>
+            </Row>
+            <Row>
+              <Button size='default' onClick={onSubmit}>
+                {t('保存')}
+              </Button>
+            </Row>
+          </Form.Section>
+        </Form>
+      </Spin>
+    </>
+  );
+}

+ 6 - 0
web/src/pages/Setting/index.js

@@ -9,6 +9,7 @@ import OtherSetting from '../../components/OtherSetting';
 import PersonalSetting from '../../components/PersonalSetting';
 import OperationSetting from '../../components/OperationSetting';
 import RateLimitSetting from '../../components/RateLimitSetting.js';
+import ModelSetting from '../../components/ModelSetting.js';
 
 const Setting = () => {
   const { t } = useTranslation();
@@ -34,6 +35,11 @@ const Setting = () => {
       content: <RateLimitSetting />,
       itemKey: 'ratelimit',
     });
+    panes.push({
+      tab: t('模型相关设置'),
+      content: <ModelSetting />,
+      itemKey: 'models',
+    });
     panes.push({
       tab: t('系统设置'),
       content: <SystemSetting />,