Browse Source

feat: Refactor model configuration management with new config system

- Introduce a new configuration management approach for model-specific settings
- Update Gemini settings to use the new config system with more flexible management
- Add support for dynamic configuration updates in option handling
- Modify Claude and Vertex adaptors to use new configuration methods
- Enhance web interface to support namespaced configuration keys
[email protected] 10 months ago
parent
commit
929668bead

+ 41 - 11
model/option.go

@@ -3,7 +3,7 @@ package model
 import (
 	"one-api/common"
 	"one-api/setting"
-	"one-api/setting/model_setting"
+	"one-api/setting/config"
 	"strconv"
 	"strings"
 	"time"
@@ -24,6 +24,8 @@ func AllOption() ([]*Option, error) {
 func InitOptionMap() {
 	common.OptionMapRWMutex.Lock()
 	common.OptionMap = make(map[string]string)
+
+	// 添加原有的系统配置
 	common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
 	common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
 	common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
@@ -111,13 +113,16 @@ func InitOptionMap() {
 	common.OptionMap["DemoSiteEnabled"] = strconv.FormatBool(setting.DemoSiteEnabled)
 	common.OptionMap["ModelRequestRateLimitEnabled"] = strconv.FormatBool(setting.ModelRequestRateLimitEnabled)
 	common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
-	//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
 	common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
 	common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
-	common.OptionMap["GeminiSafetySettings"] = model_setting.GeminiSafetySettingsJsonString()
-	common.OptionMap["GeminiVersionSettings"] = model_setting.GeminiVersionSettingsJsonString()
+
+	// 自动添加所有注册的模型配置
+	modelConfigs := config.GlobalConfig.ExportAllConfigs()
+	for k, v := range modelConfigs {
+		common.OptionMap[k] = v
+	}
 
 	common.OptionMapRWMutex.Unlock()
 	loadOptionsFromDatabase()
@@ -161,6 +166,13 @@ func updateOptionMap(key string, value string) (err error) {
 	common.OptionMapRWMutex.Lock()
 	defer common.OptionMapRWMutex.Unlock()
 	common.OptionMap[key] = value
+
+	// 检查是否是模型配置 - 使用更规范的方式处理
+	if handleConfigUpdate(key, value) {
+		return nil // 已由配置系统处理
+	}
+
+	// 处理传统配置项...
 	if strings.HasSuffix(key, "Permission") {
 		intValue, _ := strconv.Atoi(value)
 		switch key {
@@ -235,9 +247,6 @@ func updateOptionMap(key string, value string) (err error) {
 			setting.CheckSensitiveOnPromptEnabled = boolValue
 		case "ModelRequestRateLimitEnabled":
 			setting.ModelRequestRateLimitEnabled = boolValue
-
-		//case "CheckSensitiveOnCompletionEnabled":
-		//	constant.CheckSensitiveOnCompletionEnabled = boolValue
 		case "StopOnSensitiveEnabled":
 			setting.StopOnSensitiveEnabled = boolValue
 		case "SMTPSSLEnabled":
@@ -354,12 +363,33 @@ func updateOptionMap(key string, value string) (err error) {
 		setting.SensitiveWordsFromString(value)
 	case "AutomaticDisableKeywords":
 		setting.AutomaticDisableKeywordsFromString(value)
-	case "GeminiSafetySettings":
-		model_setting.GeminiSafetySettingFromJsonString(value)
-	case "GeminiVersionSettings":
-		model_setting.GeminiVersionSettingFromJsonString(value)
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 	}
 	return err
 }
+
+// handleConfigUpdate 处理分层配置更新,返回是否已处理
+func handleConfigUpdate(key, value string) bool {
+	parts := strings.SplitN(key, ".", 2)
+	if len(parts) != 2 {
+		return false // 不是分层配置
+	}
+
+	configName := parts[0]
+	configKey := parts[1]
+
+	// 获取配置对象
+	cfg := config.GlobalConfig.Get(configName)
+	if cfg == nil {
+		return false // 未注册的配置
+	}
+
+	// 更新配置
+	configMap := map[string]string{
+		configKey: value,
+	}
+	config.UpdateConfigFromMap(cfg, configMap)
+
+	return true // 已处理
+}

+ 2 - 0
relay/channel/aws/adaptor.go

@@ -8,6 +8,7 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel/claude"
 	relaycommon "one-api/relay/common"
+	"one-api/setting/model_setting"
 )
 
 const (
@@ -38,6 +39,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
+	model_setting.GetClaudeSettings().WriteHeaders(req)
 	return nil
 }
 

+ 2 - 0
relay/channel/claude/adaptor.go

@@ -9,6 +9,7 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
+	"one-api/setting/model_setting"
 	"strings"
 )
 
@@ -55,6 +56,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *rel
 		anthropicVersion = "2023-06-01"
 	}
 	req.Set("anthropic-version", anthropicVersion)
+	model_setting.GetClaudeSettings().WriteHeaders(req)
 	return nil
 }
 

+ 5 - 3
relay/channel/claude/relay-claude.go

@@ -10,6 +10,7 @@ import (
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
+	"one-api/setting/model_setting"
 	"strings"
 
 	"github.com/gin-gonic/gin"
@@ -93,9 +94,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		Tools:         claudeTools,
 	}
 
-	if strings.HasSuffix(textRequest.Model, "-thinking") {
+	if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
+		strings.HasSuffix(textRequest.Model, "-thinking") {
 		if claudeRequest.MaxTokens == 0 {
-			claudeRequest.MaxTokens = 8192
+			claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().ThinkingAdapterMaxTokens)
 		}
 
 		// 因为BudgetTokens 必须大于1024
@@ -106,7 +108,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 		// BudgetTokens 为 max_tokens 的 80%
 		claudeRequest.Thinking = &Thinking{
 			Type:         "enabled",
-			BudgetTokens: int(float64(claudeRequest.MaxTokens) * 0.8),
+			BudgetTokens: int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage),
 		}
 		// TODO: 临时处理
 		// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking

+ 1 - 1
relay/channel/vertex/adaptor.go

@@ -28,6 +28,7 @@ var claudeModelMap = map[string]string{
 	"claude-3-opus-20240229":     "claude-3-opus@20240229",
 	"claude-3-haiku-20240307":    "claude-3-haiku@20240307",
 	"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
+	"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
 }
 
 const anthropicVersion = "vertex-2023-10-16"
@@ -156,7 +157,6 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
 	return nil, errors.New("not implemented")
 }
 
-
 func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
 	return channel.DoApiRequest(a, c, info, requestBody)
 }

+ 32 - 63
setting/model_setting/gemini.go

@@ -1,83 +1,52 @@
 package model_setting
 
 import (
-	"encoding/json"
-	"one-api/common"
+	"one-api/setting/config"
 )
 
-var geminiSafetySettings = map[string]string{
-	"default":                       "OFF",
-	"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
+// GeminiSettings 定义Gemini模型的配置
+type GeminiSettings struct {
+	SafetySettings  map[string]string `json:"safety_settings"`
+	VersionSettings map[string]string `json:"version_settings"`
 }
 
-func GetGeminiSafetySetting(key string) string {
-	if value, ok := geminiSafetySettings[key]; ok {
-		return value
-	}
-	return geminiSafetySettings["default"]
+// 默认配置
+var defaultGeminiSettings = GeminiSettings{
+	SafetySettings: map[string]string{
+		"default":                       "OFF",
+		"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
+	},
+	VersionSettings: map[string]string{
+		"default":        "v1beta",
+		"gemini-1.0-pro": "v1",
+	},
 }
 
-func GeminiSafetySettingFromJsonString(jsonString string) {
-	geminiSafetySettings = map[string]string{}
-	err := json.Unmarshal([]byte(jsonString), &geminiSafetySettings)
-	if err != nil {
-		geminiSafetySettings = map[string]string{
-			"default":                       "OFF",
-			"HARM_CATEGORY_CIVIC_INTEGRITY": "BLOCK_NONE",
-		}
-	}
-	// check must have default
-	if _, ok := geminiSafetySettings["default"]; !ok {
-		geminiSafetySettings["default"] = common.GeminiSafetySetting
-	}
-}
+// 全局实例
+var geminiSettings = defaultGeminiSettings
 
-func GeminiSafetySettingsJsonString() string {
-	// check must have default
-	if _, ok := geminiSafetySettings["default"]; !ok {
-		geminiSafetySettings["default"] = common.GeminiSafetySetting
-	}
-	jsonString, err := json.Marshal(geminiSafetySettings)
-	if err != nil {
-		return "{}"
-	}
-	return string(jsonString)
+func init() {
+	// 注册到全局配置管理器
+	config.GlobalConfig.Register("gemini", &geminiSettings)
 }
 
-var geminiVersionSettings = map[string]string{
-	"default":        "v1beta",
-	"gemini-1.0-pro": "v1",
+// GetGeminiSettings 获取Gemini配置
+func GetGeminiSettings() *GeminiSettings {
+	return &geminiSettings
 }
 
-func GetGeminiVersionSetting(key string) string {
-	if value, ok := geminiVersionSettings[key]; ok {
+// GetGeminiSafetySetting 获取安全设置
+func GetGeminiSafetySetting(key string) string {
+	if value, ok := geminiSettings.SafetySettings[key]; ok {
 		return value
 	}
-	return geminiVersionSettings["default"]
-}
-
-func GeminiVersionSettingFromJsonString(jsonString string) {
-	geminiVersionSettings = map[string]string{}
-	err := json.Unmarshal([]byte(jsonString), &geminiVersionSettings)
-	if err != nil {
-		geminiVersionSettings = map[string]string{
-			"default": "v1beta",
-		}
-	}
-	// check must have default
-	if _, ok := geminiVersionSettings["default"]; !ok {
-		geminiVersionSettings["default"] = "v1beta"
-	}
+	return geminiSettings.SafetySettings["default"]
 }
 
-func GeminiVersionSettingsJsonString() string {
-	// check must have default
-	if _, ok := geminiVersionSettings["default"]; !ok {
-		geminiVersionSettings["default"] = "v1beta"
-	}
-	jsonString, err := json.Marshal(geminiVersionSettings)
-	if err != nil {
-		return "{}"
+// GetGeminiVersionSetting 获取版本设置
+func GetGeminiVersionSetting(key string) string {
+	if value, ok := geminiSettings.VersionSettings[key]; ok {
+		return value
 	}
-	return string(jsonString)
+	return geminiSettings.VersionSettings["default"]
 }

+ 6 - 6
web/src/pages/Setting/Model/SettingGeminiModel.js

@@ -24,8 +24,8 @@ export default function SettingGeminiModel(props) {
 
   const [loading, setLoading] = useState(false);
   const [inputs, setInputs] = useState({
-    GeminiSafetySettings: '',
-    GeminiVersionSettings: '',
+    'gemini.safety_settings': '',
+    'gemini.version_settings': '',
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -90,7 +90,7 @@ export default function SettingGeminiModel(props) {
                 <Form.TextArea
                   label={t('Gemini安全设置')}
                   placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_SETTING_EXAMPLE, null, 2)}
-                  field={'GeminiSafetySettings'}
+                  field={'gemini.safety_settings'}
                   extraText={t('default为默认设置,可单独设置每个分类的安全等级')}
                   autosize={{ minRows: 6, maxRows: 12 }}
                   trigger='blur'
@@ -101,7 +101,7 @@ export default function SettingGeminiModel(props) {
                       message: t('不是合法的 JSON 字符串')
                     }
                   ]}
-                  onChange={(value) => setInputs({ ...inputs, GeminiSafetySettings: value })}
+                  onChange={(value) => setInputs({ ...inputs, 'gemini.safety_settings': value })}
                 />
               </Col>
             </Row>
@@ -110,7 +110,7 @@ export default function SettingGeminiModel(props) {
                 <Form.TextArea
                   label={t('Gemini版本设置')}
                   placeholder={t('为一个 JSON 文本,例如:') + '\n' + JSON.stringify(GEMINI_VERSION_EXAMPLE, null, 2)}
-                  field={'GeminiVersionSettings'}
+                  field={'gemini.version_settings'}
                   extraText={t('default为默认设置,可单独设置每个模型的版本')}
                   autosize={{ minRows: 6, maxRows: 12 }}
                   trigger='blur'
@@ -121,7 +121,7 @@ export default function SettingGeminiModel(props) {
                       message: t('不是合法的 JSON 字符串')
                     }
                   ]}
-                  onChange={(value) => setInputs({ ...inputs, GeminiVersionSettings: value })}
+                  onChange={(value) => setInputs({ ...inputs, 'gemini.version_settings': value })}
                 />
               </Col>
             </Row>