Explorar o código

feat: able to set group ratio now (close #62, close #142)

JustSong %!s(int64=2) %!d(string=hai) anos
pai
achega
596446dba4

+ 30 - 0
common/group-ratio.go

@@ -0,0 +1,30 @@
+package common
+
+import "encoding/json"
+
+var GroupRatio = map[string]float64{
+	"default": 1,
+	"vip":     1,
+	"svip":    1,
+}
+
+func GroupRatio2JSONString() string {
+	jsonBytes, err := json.Marshal(GroupRatio)
+	if err != nil {
+		SysError("Error marshalling model ratio: " + err.Error())
+	}
+	return string(jsonBytes)
+}
+
+func UpdateGroupRatioByJSONString(jsonStr string) error {
+	return json.Unmarshal([]byte(jsonStr), &GroupRatio)
+}
+
+func GetGroupRatio(name string) float64 {
+	ratio, ok := GroupRatio[name]
+	if !ok {
+		SysError("Group ratio not found: " + name)
+		return 1
+	}
+	return ratio
+}

+ 19 - 0
controller/group.go

@@ -0,0 +1,19 @@
+package controller
+
+import (
+	"github.com/gin-gonic/gin"
+	"net/http"
+	"one-api/common"
+)
+
+func GetGroups(c *gin.Context) {
+	groupNames := make([]string, 0)
+	for groupName, _ := range common.GroupRatio {
+		groupNames = append(groupNames, groupName)
+	}
+	c.JSON(http.StatusOK, gin.H{
+		"success": true,
+		"message": "",
+		"data":    groupNames,
+	})
+}

+ 2 - 1
controller/relay.go

@@ -140,6 +140,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	channelType := c.GetInt("channel")
 	tokenId := c.GetInt("token_id")
 	consumeQuota := c.GetBool("consume_quota")
+	group := c.GetString("group")
 	var textRequest GeneralOpenAIRequest
 	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 		err := common.UnmarshalBodyReusable(c, &textRequest)
@@ -194,7 +195,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
 	}
-	ratio := common.GetModelRatio(textRequest.Model)
+	ratio := common.GetModelRatio(textRequest.Model) * common.GetGroupRatio(group)
 	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 	if consumeQuota {
 		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)

+ 3 - 2
middleware/distributor.go

@@ -16,6 +16,9 @@ type ModelRequest struct {
 
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
+		userId := c.GetInt("id")
+		userGroup, _ := model.GetUserGroup(userId)
+		c.Set("group", userGroup)
 		var channel *model.Channel
 		channelId, ok := c.Get("channelId")
 		if ok {
@@ -70,8 +73,6 @@ func Distribute() func(c *gin.Context) {
 					modelRequest.Model = "text-moderation-stable"
 				}
 			}
-			userId := c.GetInt("id")
-			userGroup, _ := model.GetUserGroup(userId)
 			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				c.JSON(200, gin.H{

+ 3 - 0
model/option.go

@@ -58,6 +58,7 @@ func InitOptionMap() {
 	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
+	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
 	common.OptionMap["TopUpLink"] = common.TopUpLink
 	common.OptionMapRWMutex.Unlock()
 	loadOptionsFromDatabase()
@@ -177,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) {
 		common.PreConsumedQuota, _ = strconv.Atoi(value)
 	case "ModelRatio":
 		err = common.UpdateModelRatioByJSONString(value)
+	case "GroupRatio":
+		err = common.UpdateGroupRatioByJSONString(value)
 	case "TopUpLink":
 		common.TopUpLink = value
 	case "ChannelDisableThreshold":

+ 5 - 0
router/api-router.go

@@ -98,5 +98,10 @@ func SetApiRouter(router *gin.Engine) {
 		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 		logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
 		logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
+		groupRoute := apiRouter.Group("/group")
+		groupRoute.Use(middleware.AdminAuth())
+		{
+			groupRoute.GET("/", controller.GetGroups)
+		}
 	}
 }

+ 20 - 0
web/src/components/SystemSetting.js

@@ -30,6 +30,7 @@ const SystemSetting = () => {
     QuotaRemindThreshold: 0,
     PreConsumedQuota: 0,
     ModelRatio: '',
+    GroupRatio: '',
     TopUpLink: '',
     AutomaticDisableChannelEnabled: '',
     ChannelDisableThreshold: 0,
@@ -101,6 +102,7 @@ const SystemSetting = () => {
       name === 'QuotaRemindThreshold' ||
       name === 'PreConsumedQuota' ||
       name === 'ModelRatio' ||
+      name === 'GroupRatio' ||
       name === 'TopUpLink'
     ) {
       setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -131,6 +133,13 @@ const SystemSetting = () => {
       }
       await updateOption('ModelRatio', inputs.ModelRatio);
     }
+    if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
+      if (!verifyJSON(inputs.GroupRatio)) {
+        showError('分组倍率不是合法的 JSON 字符串');
+        return;
+      }
+      await updateOption('GroupRatio', inputs.GroupRatio);
+    }
     if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
       await updateOption('TopUpLink', inputs.TopUpLink);
     }
@@ -329,6 +338,17 @@ const SystemSetting = () => {
               placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
             />
           </Form.Group>
+          <Form.Group widths='equal'>
+            <Form.TextArea
+              label='分组倍率'
+              name='GroupRatio'
+              onChange={handleInputChange}
+              style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
+              autoComplete='new-password'
+              value={inputs.GroupRatio}
+              placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
+            />
+          </Form.Group>
           <Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
           <Divider />
           <Header as='h3'>

+ 4 - 0
web/src/helpers/render.js

@@ -10,6 +10,10 @@ export function renderText(text, limit) {
 export function renderGroup(group) {
   if (group === "") {
     return <Label>default</Label>
+  } else if (group === "vip" || group === "pro") {
+    return <Label color='yellow'>{group}</Label>
+  } else if (group === "svip" || group === "premium") {
+    return <Label color='red'>{group}</Label>
   }
   return <Label>{group}</Label>
 }

+ 23 - 2
web/src/pages/Channel/EditChannel.js

@@ -21,6 +21,7 @@ const EditChannel = () => {
   const [batch, setBatch] = useState(false);
   const [inputs, setInputs] = useState(originInputs);
   const [modelOptions, setModelOptions] = useState([]);
+  const [groupOptions, setGroupOptions] = useState([]);
   const [basicModels, setBasicModels] = useState([]);
   const [fullModels, setFullModels] = useState([]);
   const handleInputChange = (e, { name, value }) => {
@@ -58,11 +59,25 @@ const EditChannel = () => {
     }
   };
 
+  const fetchGroups = async () => {
+    try {
+      let res = await API.get(`/api/group`);
+      setGroupOptions(res.data.data.map((group) => ({
+        key: group,
+        text: group,
+        value: group,
+      })));
+    } catch (error) {
+      showError(error.message);
+    }
+  };
+
   useEffect(() => {
     if (isEdit) {
       loadChannel().then();
     }
     fetchModels().then();
+    fetchGroups().then();
   }, []);
 
   const submit = async () => {
@@ -167,13 +182,19 @@ const EditChannel = () => {
             />
           </Form.Field>
           <Form.Field>
-            <Form.Input
+            <Form.Dropdown
               label='分组'
+              placeholder={'请选择分组'}
               name='group'
-              placeholder={'请输入分组'}
+              fluid
+              search
+              selection
+              allowAdditions
+              additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
               onChange={handleInputChange}
               value={inputs.group}
               autoComplete='new-password'
+              options={groupOptions}
             />
           </Form.Field>
           <Form.Field>

+ 25 - 3
web/src/pages/User/EditUser.js

@@ -17,11 +17,24 @@ const EditUser = () => {
     quota: 0,
     group: 'default'
   });
+  const [groupOptions, setGroupOptions] = useState([]);
   const { username, display_name, password, github_id, wechat_id, email, quota, group } =
     inputs;
   const handleInputChange = (e, { name, value }) => {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
   };
+  const fetchGroups = async () => {
+    try {
+      let res = await API.get(`/api/group`);
+      setGroupOptions(res.data.data.map((group) => ({
+        key: group,
+        text: group,
+        value: group,
+      })));
+    } catch (error) {
+      showError(error.message);
+    }
+  };
 
   const loadUser = async () => {
     let res = undefined;
@@ -41,6 +54,9 @@ const EditUser = () => {
   };
   useEffect(() => {
     loadUser().then();
+    if (userId) {
+      fetchGroups().then();
+    }
   }, []);
 
   const submit = async () => {
@@ -101,13 +117,19 @@ const EditUser = () => {
           {
             userId && <>
               <Form.Field>
-                <Form.Input
+                <Form.Dropdown
                   label='分组'
+                  placeholder={'请选择分组'}
                   name='group'
-                  placeholder={'请输入用户分组'}
+                  fluid
+                  search
+                  selection
+                  allowAdditions
+                  additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
                   onChange={handleInputChange}
-                  value={group}
+                  value={inputs.group}
                   autoComplete='new-password'
+                  options={groupOptions}
                 />
               </Form.Field>
               <Form.Field>