Kaynağa Gözat

feat: Add automatic channel disabling based on configurable keywords

- Introduce AutomaticDisableKeywords setting to dynamically control channel disabling
- Implement AC search for matching error messages against disable keywords
- Add frontend UI for configuring automatic disable keywords
- Update localization with new keyword-based channel disabling feature
- Refactor sensitive word and AC search logic to support multiple keyword lists
[email protected] 10 ay önce
ebeveyn
işleme
9edb9f7a71

+ 1 - 0
common/database.go

@@ -3,5 +3,6 @@ package common
 var UsingSQLite = false
 var UsingPostgreSQL = false
 var UsingMySQL = false
+var UsingClickHouse = false
 
 var SQLitePath = "one-api.db?_busy_timeout=5000"

+ 13 - 15
controller/channel-test.go

@@ -41,36 +41,34 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 	}
 	w := httptest.NewRecorder()
 	c, _ := gin.CreateTestContext(w)
-	
+
 	requestPath := "/v1/chat/completions"
-	
+
 	// 先判断是否为 Embedding 模型
 	if strings.Contains(strings.ToLower(testModel), "embedding") ||
-		strings.HasPrefix(testModel, "m3e") ||  // m3e 系列模型
-		strings.Contains(testModel, "bge-") ||  // bge 系列模型
+		strings.HasPrefix(testModel, "m3e") || // m3e 系列模型
+		strings.Contains(testModel, "bge-") || // bge 系列模型
 		testModel == "text-embedding-v1" ||
-		channel.Type == common.ChannelTypeMokaAI{      // 其他 embedding 模型
-		requestPath = "/v1/embeddings"  // 修改请求路径
+		channel.Type == common.ChannelTypeMokaAI { // 其他 embedding 模型
+		requestPath = "/v1/embeddings" // 修改请求路径
 	}
-	
+
 	c.Request = &http.Request{
 		Method: "POST",
-		URL:    &url.URL{Path: requestPath},  // 使用动态路径
+		URL:    &url.URL{Path: requestPath}, // 使用动态路径
 		Body:   nil,
 		Header: make(http.Header),
 	}
 
 	if testModel == "" {
-		common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 是 %s", string(*channel.TestModel)))
 		if channel.TestModel != nil && *channel.TestModel != "" {
 			testModel = *channel.TestModel
 		} else {
 			if len(channel.GetModels()) > 0 {
 				testModel = channel.GetModels()[0]
 			} else {
-				testModel = "gpt-3.5-turbo"
+				testModel = "gpt-4o-mini"
 			}
-			common.SysLog(fmt.Sprintf("testModel 为空, channel 的 TestModel 为空:", string(testModel)))
 		}
 	}
 
@@ -102,7 +100,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 
 	request := buildTestRequest(testModel)
 	meta.UpstreamModelName = testModel
-	common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %s ", channel.Id, testModel, meta))
+	common.SysLog(fmt.Sprintf("testing channel %d with model %s , meta %v ", channel.Id, testModel, meta))
 
 	adaptor.Init(meta)
 
@@ -173,9 +171,9 @@ func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 
 	// 先判断是否为 Embedding 模型
 	if strings.Contains(strings.ToLower(model), "embedding") ||
-		strings.HasPrefix(model, "m3e") ||  // m3e 系列模型
-		strings.Contains(model, "bge-") ||  // bge 系列模型
-		model == "text-embedding-v1" {      // 其他 embedding 模型
+		strings.HasPrefix(model, "m3e") || // m3e 系列模型
+		strings.Contains(model, "bge-") || // bge 系列模型
+		model == "text-embedding-v1" { // 其他 embedding 模型
 		// Embedding 请求
 		testRequest.Input = []string{"hello world"}
 		return testRequest

+ 3 - 0
model/option.go

@@ -110,6 +110,7 @@ func InitOptionMap() {
 	common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
 	common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
 	common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
+	common.OptionMap["AutomaticDisableKeywords"] = setting.AutomaticDisableKeywordsToString()
 
 	common.OptionMapRWMutex.Unlock()
 	loadOptionsFromDatabase()
@@ -335,6 +336,8 @@ func updateOptionMap(key string, value string) (err error) {
 		common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
 	case "SensitiveWords":
 		setting.SensitiveWordsFromString(value)
+	case "AutomaticDisableKeywords":
+		setting.AutomaticDisableKeywordsFromString(value)
 	case "StreamCacheQueueLength":
 		setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
 	}

+ 4 - 14
service/channel.go

@@ -6,6 +6,7 @@ import (
 	"one-api/common"
 	relaymodel "one-api/dto"
 	"one-api/model"
+	"one-api/setting"
 	"strings"
 )
 
@@ -64,21 +65,10 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus
 	case "forbidden":
 		return true
 	}
-	if strings.HasPrefix(err.Error.Message, "Your credit balance is too low") { // anthropic
-		return true
-	} else if strings.HasPrefix(err.Error.Message, "This organization has been disabled.") {
-		return true
-	} else if strings.HasPrefix(err.Error.Message, "You exceeded your current quota") {
-		return true
-	} else if strings.HasPrefix(err.Error.Message, "Permission denied") {
-		return true
-	}
 
-	if strings.Contains(err.Error.Message, "The security token included in the request is invalid") { // anthropic
-		return true
-	} else if strings.Contains(err.Error.Message, "Operation not allowed") {
-		return true
-	} else if strings.Contains(err.Error.Message, "Your account is not authorized") {
+	lowerMessage := strings.ToLower(err.Error.Message)
+	search, _ := AcSearch(lowerMessage, setting.AutomaticDisableKeywords, true)
+	if search {
 		return true
 	}
 

+ 2 - 12
service/sensitive.go

@@ -60,17 +60,7 @@ func SensitiveWordContains(text string) (bool, []string) {
 		return false, nil
 	}
 	checkText := strings.ToLower(text)
-	// 构建一个AC自动机
-	m := InitAc()
-	hits := m.MultiPatternSearch([]rune(checkText), false)
-	if len(hits) > 0 {
-		words := make([]string, 0)
-		for _, hit := range hits {
-			words = append(words, string(hit.Word))
-		}
-		return true, words
-	}
-	return false, nil
+	return AcSearch(checkText, setting.SensitiveWords, false)
 }
 
 // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
@@ -79,7 +69,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
 		return false, nil, text
 	}
 	checkText := strings.ToLower(text)
-	m := InitAc()
+	m := InitAc(setting.SensitiveWords)
 	hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
 	if len(hits) > 0 {
 		words := make([]string, 0)

+ 26 - 5
service/str.go

@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"fmt"
 	goahocorasick "github.com/anknown/ahocorasick"
-	"one-api/setting"
 	"strings"
 )
 
@@ -57,9 +56,9 @@ func RemoveDuplicate(s []string) []string {
 	return result
 }
 
-func InitAc() *goahocorasick.Machine {
+func InitAc(words []string) *goahocorasick.Machine {
 	m := new(goahocorasick.Machine)
-	dict := readRunes()
+	dict := readRunes(words)
 	if err := m.Build(dict); err != nil {
 		fmt.Println(err)
 		return nil
@@ -67,10 +66,10 @@ func InitAc() *goahocorasick.Machine {
 	return m
 }
 
-func readRunes() [][]rune {
+func readRunes(words []string) [][]rune {
 	var dict [][]rune
 
-	for _, word := range setting.SensitiveWords {
+	for _, word := range words {
 		word = strings.ToLower(word)
 		l := bytes.TrimSpace([]byte(word))
 		dict = append(dict, bytes.Runes(l))
@@ -78,3 +77,25 @@ func readRunes() [][]rune {
 
 	return dict
 }
+
+func AcSearch(findText string, dict []string, stopImmediately bool) (bool, []string) {
+	if len(dict) == 0 {
+		return false, nil
+	}
+	if len(findText) == 0 {
+		return false, nil
+	}
+	m := InitAc(dict)
+	if m == nil {
+		return false, nil
+	}
+	hits := m.MultiPatternSearch([]rune(findText), stopImmediately)
+	if len(hits) > 0 {
+		words := make([]string, 0)
+		for _, hit := range hits {
+			words = append(words, string(hit.Word))
+		}
+		return true, words
+	}
+	return false, nil
+}

+ 27 - 0
setting/operation_setting.go

@@ -1,3 +1,30 @@
 package setting
 
+import "strings"
+
 var DemoSiteEnabled = false
+
+var AutomaticDisableKeywords = []string{
+	"Your credit balance is too low",
+	"This organization has been disabled.",
+	"You exceeded your current quota",
+	"Permission denied",
+	"The security token included in the request is invalid",
+	"Operation not allowed",
+	"Your account is not authorized",
+}
+
+func AutomaticDisableKeywordsToString() string {
+	return strings.Join(AutomaticDisableKeywords, "\n")
+}
+
+func AutomaticDisableKeywordsFromString(s string) {
+	AutomaticDisableKeywords = []string{}
+	ak := strings.Split(s, "\n")
+	for _, k := range ak {
+		k = strings.TrimSpace(k)
+		if k != "" {
+			AutomaticDisableKeywords = append(AutomaticDisableKeywords, k)
+		}
+	}
+}

+ 1 - 0
web/src/components/OperationSetting.js

@@ -59,6 +59,7 @@ const OperationSetting = () => {
     RetryTimes: 0,
     Chats: "[]",
     DemoSiteEnabled: false,
+    AutomaticDisableKeywords: '',
   });
 
   let [loading, setLoading] = useState(false);

+ 5 - 2
web/src/i18n/locales/en.json

@@ -201,7 +201,7 @@
   "相关 API 显示令牌额度而非用户额度": "Related APIs show token quota instead of user quota",
   "保存通用设置": "Save General Settings",
   "监控设置": "Monitoring Settings",
-  "最长响应时间": "Maximum Response Time",
+  "测试所有渠道的最长响应时间": "Maximum response time for testing all channels",
   "单位秒": "Unit: seconds",
   "当运行通道全部测试时": "When running all channel tests",
   "超过此时间将自动禁用通道": "Channels exceeding this time will be automatically disabled",
@@ -1246,5 +1246,8 @@
   "请输入要设置的标签名称": "Please enter the tag name to be set",
   "请输入标签名称": "Please enter the tag name",
   "支持搜索用户的 ID、用户名、显示名称和邮箱地址": "Support searching for user ID, username, display name, and email address",
-  "已注销": "Logged out"
+  "已注销": "Logged out",
+  "自动禁用关键词": "Automatic disable keywords",
+  "一行一个,不区分大小写": "One line per keyword, not case-sensitive",
+  "当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道": "When the upstream channel returns an error containing these keywords (not case-sensitive), automatically disable the channel"
 }

+ 15 - 2
web/src/pages/Setting/Operation/SettingsMonitoring.js

@@ -5,7 +5,7 @@ import {
   API,
   showError,
   showSuccess,
-  showWarning,
+  showWarning, verifyJSON
 } from '../../../helpers';
 import { useTranslation } from 'react-i18next';
 
@@ -17,6 +17,7 @@ export default function SettingsMonitoring(props) {
     QuotaRemindThreshold: '',
     AutomaticDisableChannelEnabled: false,
     AutomaticEnableChannelEnabled: false,
+    AutomaticDisableKeywords: '',
   });
   const refForm = useRef();
   const [inputsRow, setInputsRow] = useState(inputs);
@@ -79,7 +80,7 @@ export default function SettingsMonitoring(props) {
             <Row gutter={16}>
               <Col span={8}>
                 <Form.InputNumber
-                  label={t('最长响应时间')}
+                  label={t('测试所有渠道的最长响应时间')}
                   step={1}
                   min={0}
                   suffix={t('秒')}
@@ -144,6 +145,18 @@ export default function SettingsMonitoring(props) {
                 />
               </Col>
             </Row>
+            <Row gutter={16}>
+              <Col span={16}>
+                <Form.TextArea
+                  label={t('自动禁用关键词')}
+                  placeholder={t('一行一个,不区分大小写')}
+                  extraText={t('当上游通道返回错误中包含这些关键词时(不区分大小写),自动禁用通道')}
+                  field={'AutomaticDisableKeywords'}
+                  autosize={{ minRows: 6, maxRows: 12 }}
+                  onChange={(value) => setInputs({ ...inputs, AutomaticDisableKeywords: value })}
+                />
+              </Col>
+            </Row>
             <Row>
               <Button size='default' onClick={onSubmit}>
                 {t('保存监控设置')}