Przeglądaj źródła

feat: support group now (close #17, close #72, close #85, close #104, close #136)

Co-authored-by: quzard <[email protected]>
JustSong 2 lat temu
rodzic
commit
2ad22e1425

+ 26 - 0
common/gin.go

@@ -0,0 +1,26 @@
+package common
+
+import (
+	"bytes"
+	"encoding/json"
+	"github.com/gin-gonic/gin"
+	"io"
+)
+
+func UnmarshalBodyReusable(c *gin.Context, v any) error {
+	requestBody, err := io.ReadAll(c.Request.Body)
+	if err != nil {
+		return err
+	}
+	err = c.Request.Body.Close()
+	if err != nil {
+		return err
+	}
+	err = json.Unmarshal(requestBody, &v)
+	if err != nil {
+		return err
+	}
+	// Reset request body
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
+	return nil
+}

+ 2 - 12
controller/relay.go

@@ -116,20 +116,10 @@ func relayHelper(c *gin.Context) *OpenAIErrorWithStatusCode {
 	consumeQuota := c.GetBool("consume_quota")
 	var textRequest GeneralOpenAIRequest
 	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
-		requestBody, err := io.ReadAll(c.Request.Body)
+		err := common.UnmarshalBodyReusable(c, &textRequest)
 		if err != nil {
-			return errorWrapper(err, "read_request_body_failed", http.StatusBadRequest)
+			return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
 		}
-		err = c.Request.Body.Close()
-		if err != nil {
-			return errorWrapper(err, "close_request_body_failed", http.StatusBadRequest)
-		}
-		err = json.Unmarshal(requestBody, &textRequest)
-		if err != nil {
-			return errorWrapper(err, "unmarshal_request_body_failed", http.StatusBadRequest)
-		}
-		// Reset request body
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()

+ 19 - 2
middleware/distributor.go

@@ -9,6 +9,10 @@ import (
 	"strconv"
 )
 
+type ModelRequest struct {
+	Model string `json:"model"`
+}
+
 func Distribute() func(c *gin.Context) {
 	return func(c *gin.Context) {
 		var channel *model.Channel
@@ -48,8 +52,21 @@ func Distribute() func(c *gin.Context) {
 			}
 		} else {
 			// Select a channel for the user
-			var err error
-			channel, err = model.GetRandomChannel()
+			var modelRequest ModelRequest
+			err := common.UnmarshalBodyReusable(c, &modelRequest)
+			if err != nil {
+				c.JSON(200, gin.H{
+					"error": gin.H{
+						"message": "无效的请求",
+						"type":    "one_api_error",
+					},
+				})
+				c.Abort()
+				return
+			}
+			userId := c.GetInt("id")
+			userGroup, _ := model.GetUserGroup(userId)
+			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 			if err != nil {
 				c.JSON(200, gin.H{
 					"error": gin.H{

+ 72 - 0
model/ability.go

@@ -0,0 +1,72 @@
+package model
+
+import (
+	"one-api/common"
+	"strings"
+)
+
+type Ability struct {
+	Group     string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
+	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"`
+	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
+	Enabled   bool   `json:"enabled" gorm:"default:1"`
+}
+
+func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+	if group == "default" {
+		return GetRandomChannel()
+	}
+	ability := Ability{}
+	var err error = nil
+	if common.UsingSQLite {
+		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RANDOM()").Limit(1).First(&ability).Error
+	} else {
+		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("RAND()").Limit(1).First(&ability).Error
+	}
+	if err != nil {
+		return nil, err
+	}
+	channel := Channel{}
+	err = DB.First(&channel, "id = ?", ability.ChannelId).Error
+	return &channel, err
+}
+
+func (channel *Channel) AddAbilities() error {
+	models_ := strings.Split(channel.Models, ",")
+	abilities := make([]Ability, 0, len(models_))
+	for _, model := range models_ {
+		ability := Ability{
+			Group:     channel.Group,
+			Model:     model,
+			ChannelId: channel.Id,
+			Enabled:   channel.Status == common.ChannelStatusEnabled,
+		}
+		abilities = append(abilities, ability)
+	}
+	return DB.Create(&abilities).Error
+}
+
+func (channel *Channel) DeleteAbilities() error {
+	return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
+}
+
+// UpdateAbilities updates abilities of this channel.
+// Make sure the channel is completed before calling this function.
+func (channel *Channel) UpdateAbilities() error {
+	// A quick and dirty way to update abilities
+	// First delete all abilities of this channel
+	err := channel.DeleteAbilities()
+	if err != nil {
+		return err
+	}
+	// Then add new abilities
+	err = channel.AddAbilities()
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func UpdateAbilityStatus(channelId int, status bool) error {
+	return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Update("enabled", status).Error
+}

+ 31 - 6
model/channel.go

@@ -1,7 +1,6 @@
 package model
 
 import (
-	_ "gorm.io/driver/sqlite"
 	"one-api/common"
 )
 
@@ -19,6 +18,8 @@ type Channel struct {
 	Other              string  `json:"other"`
 	Balance            float64 `json:"balance"` // in USD
 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"`
+	Models             string  `json:"models"`
+	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"`
 }
 
 func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -49,13 +50,12 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
 }
 
 func GetRandomChannel() (*Channel, error) {
-	// TODO: consider weight
 	channel := Channel{}
 	var err error = nil
 	if common.UsingSQLite {
-		err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RANDOM()").Limit(1).First(&channel).Error
+		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
 	} else {
-		err = DB.Where("status = ?", common.ChannelStatusEnabled).Order("RAND()").Limit(1).First(&channel).Error
+		err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
 	}
 	return &channel, err
 }
@@ -63,18 +63,35 @@ func GetRandomChannel() (*Channel, error) {
 func BatchInsertChannels(channels []Channel) error {
 	var err error
 	err = DB.Create(&channels).Error
-	return err
+	if err != nil {
+		return err
+	}
+	for _, channel_ := range channels {
+		err = channel_.AddAbilities()
+		if err != nil {
+			return err
+		}
+	}
+	return nil
 }
 
 func (channel *Channel) Insert() error {
 	var err error
 	err = DB.Create(channel).Error
+	if err != nil {
+		return err
+	}
+	err = channel.AddAbilities()
 	return err
 }
 
 func (channel *Channel) Update() error {
 	var err error
 	err = DB.Model(channel).Updates(channel).Error
+	if err != nil {
+		return err
+	}
+	err = channel.UpdateAbilities()
 	return err
 }
 
@@ -101,11 +118,19 @@ func (channel *Channel) UpdateBalance(balance float64) {
 func (channel *Channel) Delete() error {
 	var err error
 	err = DB.Delete(channel).Error
+	if err != nil {
+		return err
+	}
+	err = channel.DeleteAbilities()
 	return err
 }
 
 func UpdateChannelStatusById(id int, status int) {
-	err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
+	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
+	if err != nil {
+		common.SysError("failed to update ability status: " + err.Error())
+	}
+	err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
 	if err != nil {
 		common.SysError("failed to update channel status: " + err.Error())
 	}

+ 4 - 0
model/main.go

@@ -75,6 +75,10 @@ func InitDB() (err error) {
 		if err != nil {
 			return err
 		}
+		err = db.AutoMigrate(&Ability{})
+		if err != nil {
+			return err
+		}
 		err = createRootAccountIfNeed()
 		return err
 	} else {

+ 0 - 1
model/redemption.go

@@ -2,7 +2,6 @@ package model
 
 import (
 	"errors"
-	_ "gorm.io/driver/sqlite"
 	"one-api/common"
 )
 

+ 0 - 1
model/token.go

@@ -3,7 +3,6 @@ package model
 import (
 	"errors"
 	"fmt"
-	_ "gorm.io/driver/sqlite"
 	"gorm.io/gorm"
 	"one-api/common"
 )

+ 6 - 0
model/user.go

@@ -22,6 +22,7 @@ type User struct {
 	VerificationCode string `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
 	AccessToken      string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
 	Quota            int    `json:"quota" gorm:"type:int;default:0"`
+	Group            string `json:"group" gorm:"type:varchar(32);default:'default'"`
 }
 
 func GetMaxUserId() int {
@@ -229,6 +230,11 @@ func GetUserEmail(id int) (email string, err error) {
 	return email, err
 }
 
+func GetUserGroup(id int) (group string, err error) {
+	err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
+	return group, err
+}
+
 func IncreaseUserQuota(id int, quota int) (err error) {
 	if quota < 0 {
 		return errors.New("quota 不能为负数!")

+ 1 - 0
router/api-router.go

@@ -63,6 +63,7 @@ func SetApiRouter(router *gin.Engine) {
 		{
 			channelRoute.GET("/", controller.GetAllChannels)
 			channelRoute.GET("/search", controller.SearchChannels)
+			channelRoute.GET("/models", controller.ListModels)
 			channelRoute.GET("/:id", controller.GetChannel)
 			channelRoute.GET("/test", controller.TestAllChannels)
 			channelRoute.GET("/test/:id", controller.TestChannel)

+ 6 - 2
router/relay-router.go

@@ -8,11 +8,15 @@ import (
 
 func SetRelayRouter(router *gin.Engine) {
 	// https://platform.openai.com/docs/api-reference/introduction
+	modelsRouter := router.Group("/v1/models")
+	modelsRouter.Use(middleware.TokenAuth())
+	{
+		modelsRouter.GET("/", controller.ListModels)
+		modelsRouter.GET("/:model", controller.RetrieveModel)
+	}
 	relayV1Router := router.Group("/v1")
 	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
 	{
-		relayV1Router.GET("/models", controller.ListModels)
-		relayV1Router.GET("/models/:model", controller.RetrieveModel)
 		relayV1Router.POST("/completions", controller.RelayNotImplemented)
 		relayV1Router.POST("/chat/completions", controller.Relay)
 		relayV1Router.POST("/edits", controller.RelayNotImplemented)

+ 11 - 1
web/src/components/ChannelsTable.js

@@ -4,6 +4,7 @@ import { Link } from 'react-router-dom';
 import { API, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
 
 import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
+import { renderGroup } from '../helpers/render';
 
 function renderTimestamp(timestamp) {
   return (
@@ -264,6 +265,14 @@ const ChannelsTable = () => {
             >
               名称
             </Table.HeaderCell>
+            <Table.HeaderCell
+              style={{ cursor: 'pointer' }}
+              onClick={() => {
+                sortChannel('group');
+              }}
+            >
+              分组
+            </Table.HeaderCell>
             <Table.HeaderCell
               style={{ cursor: 'pointer' }}
               onClick={() => {
@@ -312,6 +321,7 @@ const ChannelsTable = () => {
                 <Table.Row key={channel.id}>
                   <Table.Cell>{channel.id}</Table.Cell>
                   <Table.Cell>{channel.name ? channel.name : '无'}</Table.Cell>
+                  <Table.Cell>{renderGroup(channel.group)}</Table.Cell>
                   <Table.Cell>{renderType(channel.type)}</Table.Cell>
                   <Table.Cell>{renderStatus(channel.status)}</Table.Cell>
                   <Table.Cell>
@@ -398,7 +408,7 @@ const ChannelsTable = () => {
 
         <Table.Footer>
           <Table.Row>
-            <Table.HeaderCell colSpan='7'>
+            <Table.HeaderCell colSpan='8'>
               <Button size='small' as={Link} to='/channel/add' loading={loading}>
                 添加新的渠道
               </Button>

+ 11 - 2
web/src/components/UsersTable.js

@@ -4,7 +4,7 @@ import { Link } from 'react-router-dom';
 import { API, showError, showSuccess } from '../helpers';
 
 import { ITEMS_PER_PAGE } from '../constants';
-import { renderText } from '../helpers/render';
+import { renderGroup, renderText } from '../helpers/render';
 
 function renderRole(role) {
   switch (role) {
@@ -175,6 +175,14 @@ const UsersTable = () => {
             >
               用户名
             </Table.HeaderCell>
+            <Table.HeaderCell
+              style={{ cursor: 'pointer' }}
+              onClick={() => {
+                sortUser('group');
+              }}
+            >
+              分组
+            </Table.HeaderCell>
             <Table.HeaderCell
               style={{ cursor: 'pointer' }}
               onClick={() => {
@@ -231,6 +239,7 @@ const UsersTable = () => {
                       hoverable
                     />
                   </Table.Cell>
+                  <Table.Cell>{renderGroup(user.group)}</Table.Cell>
                   <Table.Cell>{user.email ? renderText(user.email, 30) : '无'}</Table.Cell>
                   <Table.Cell>{user.quota}</Table.Cell>
                   <Table.Cell>{renderRole(user.role)}</Table.Cell>
@@ -306,7 +315,7 @@ const UsersTable = () => {
 
         <Table.Footer>
           <Table.Row>
-            <Table.HeaderCell colSpan='7'>
+            <Table.HeaderCell colSpan='8'>
               <Button size='small' as={Link} to='/user/add' loading={loading}>
                 添加新的用户
               </Button>

+ 9 - 0
web/src/helpers/render.js

@@ -1,6 +1,15 @@
+import { Label } from 'semantic-ui-react';
+
 export function renderText(text, limit) {
   if (text.length > limit) {
     return text.slice(0, limit - 3) + '...';
   }
   return text;
+}
+
+export function renderGroup(group) {
+  if (group === "") {
+    return <Label>default</Label>
+  }
+  return <Label>{group}</Label>
 }

+ 37 - 2
web/src/pages/Channel/EditChannel.js

@@ -14,10 +14,12 @@ const EditChannel = () => {
     type: 1,
     key: '',
     base_url: '',
-    other: ''
+    other: '',
+    models: [],
   };
   const [batch, setBatch] = useState(false);
   const [inputs, setInputs] = useState(originInputs);
+  const [modelOptions, setModelOptions] = useState([]);
   const handleInputChange = (e, { name, value }) => {
     console.log(name, value);
     setInputs((inputs) => ({ ...inputs, [name]: value }));
@@ -27,17 +29,36 @@ const EditChannel = () => {
     let res = await API.get(`/api/channel/${channelId}`);
     const { success, message, data } = res.data;
     if (success) {
-      data.password = '';
+      if (data.models === "") {
+        data.models = []
+      } else {
+        data.models = data.models.split(",")
+      }
       setInputs(data);
     } else {
       showError(message);
     }
     setLoading(false);
   };
+
+  const fetchModels = async () => {
+    try {
+      let res = await API.get(`/api/channel/models`);
+      setModelOptions(res.data.data.map((model) => ({
+        key: model.id,
+        text: model.id,
+        value: model.id,
+      })));
+    } catch (error) {
+      console.error('Error fetching models:', error);
+    }
+  };
+
   useEffect(() => {
     if (isEdit) {
       loadChannel().then();
     }
+    fetchModels().then();
   }, []);
 
   const submit = async () => {
@@ -50,6 +71,7 @@ const EditChannel = () => {
       localInputs.other = '2023-03-15-preview';
     }
     let res;
+    localInputs.models = localInputs.models.join(",")
     if (isEdit) {
       res = await API.put(`/api/channel/`, { ...localInputs, id: parseInt(channelId) });
     } else {
@@ -137,6 +159,19 @@ const EditChannel = () => {
               autoComplete='new-password'
             />
           </Form.Field>
+          <Form.Field>
+            <Form.Dropdown
+              label='支持的模型'
+              name='models'
+              fluid
+              multiple
+              selection
+              onChange={handleInputChange}
+              value={inputs.models}
+              autoComplete='new-password'
+              options={modelOptions}
+            />
+          </Form.Field>
           {
             batch ? <Form.Field>
               <Form.TextArea