Преглед изворни кода

feat: 加入渠道加权随机功能

CaIon пре 2 година
родитељ
комит
bdd611fd33
5 измењених фајлова са 86 додато и 8 уклоњено
  1. 5 0
      common/utils.go
  2. 30 5
      model/ability.go
  3. 24 2
      model/cache.go
  4. 7 0
      model/channel.go
  5. 20 1
      web/src/components/ChannelsTable.js

+ 5 - 0
common/utils.go

@@ -168,6 +168,11 @@ func GetRandomString(length int) string {
 	return string(key)
 }
 
+func GetRandomInt(max int) int {
+	//rand.Seed(time.Now().UnixNano())
+	return rand.Intn(max)
+}
+
 func GetTimestamp() int64 {
 	return time.Now().Unix()
 }

+ 30 - 5
model/ability.go

@@ -11,6 +11,7 @@ type Ability struct {
 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
 	Enabled   bool   `json:"enabled"`
 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"`
+	Weight    uint   `json:"weight" gorm:"default:0;index"`
 }
 
 func GetGroupModels(group string) []string {
@@ -25,7 +26,7 @@ func GetGroupModels(group string) []string {
 }
 
 func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
-	ability := Ability{}
+	var abilities []Ability
 	groupCol := "`group`"
 	trueVal := "1"
 	if common.UsingPostgreSQL {
@@ -37,16 +38,39 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
 	channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
 	if common.UsingSQLite || common.UsingPostgreSQL {
-		err = channelQuery.Order("RANDOM()").First(&ability).Error
+		err = channelQuery.Order("weight DESC").Find(&abilities).Error
 	} else {
-		err = channelQuery.Order("RAND()").First(&ability).Error
+		err = channelQuery.Order("weight DESC").Find(&abilities).Error
 	}
 	if err != nil {
 		return nil, err
 	}
 	channel := Channel{}
-	channel.Id = ability.ChannelId
-	err = DB.First(&channel, "id = ?", ability.ChannelId).Error
+	if len(abilities) > 0 {
+		// Randomly choose one
+		weightSum := uint(0)
+		for _, ability_ := range abilities {
+			weightSum += ability_.Weight
+		}
+		if weightSum == 0 {
+			// All weight is 0, randomly choose one
+			channel.Id = abilities[common.GetRandomInt(len(abilities))].ChannelId
+		} else {
+			// Randomly choose one
+			weight := common.GetRandomInt(int(weightSum))
+			for _, ability_ := range abilities {
+				weight -= int(ability_.Weight)
+				//log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
+				if weight <= 0 {
+					channel.Id = ability_.ChannelId
+					break
+				}
+			}
+		}
+	} else {
+		return nil, nil
+	}
+	err = DB.First(&channel, "id = ?", channel.Id).Error
 	return &channel, err
 }
 
@@ -62,6 +86,7 @@ func (channel *Channel) AddAbilities() error {
 				ChannelId: channel.Id,
 				Enabled:   channel.Status == common.ChannelStatusEnabled,
 				Priority:  channel.Priority,
+				Weight:    uint(channel.GetWeight()),
 			}
 			abilities = append(abilities, ability)
 		}

+ 24 - 2
model/cache.go

@@ -198,6 +198,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 		model = "gpt-4-gizmo-*"
 	}
 
+	// if memory cache is disabled, get channel directly from database
 	if !common.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model)
 	}
@@ -218,8 +219,29 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
 			}
 		}
 	}
-	idx := rand.Intn(endIdx)
-	return channels[idx], nil
+	// Calculate the total weight of all channels up to endIdx
+	totalWeight := 0
+	for _, channel := range channels[:endIdx] {
+		totalWeight += channel.GetWeight()
+	}
+
+	if totalWeight == 0 {
+		// If all weights are 0, select a channel randomly
+		return channels[rand.Intn(endIdx)], nil
+	}
+
+	// Generate a random value in the range [0, totalWeight)
+	randomWeight := rand.Intn(totalWeight)
+
+	// Find a channel based on its weight
+	for _, channel := range channels[:endIdx] {
+		randomWeight -= channel.GetWeight()
+		if randomWeight <= 0 {
+			return channel, nil
+		}
+	}
+	// return the last channel if no channel is found
+	return channels[endIdx-1], nil
 }
 
 func CacheGetChannel(id int) (*Channel, error) {

+ 7 - 0
model/channel.go

@@ -113,6 +113,13 @@ func (channel *Channel) GetPriority() int64 {
 	return *channel.Priority
 }
 
+func (channel *Channel) GetWeight() int {
+	if channel.Weight == nil {
+		return 0
+	}
+	return int(*channel.Weight)
+}
+
 func (channel *Channel) GetBaseURL() string {
 	if channel.BaseURL == nil {
 		return ""

+ 20 - 1
web/src/components/ChannelsTable.js

@@ -163,7 +163,7 @@ const ChannelsTable = () => {
                     <div>
                         <InputNumber
                             style={{width: 70}}
-                            name='name'
+                            name='priority'
                             onChange={value => {
                                 manageChannel(record.id, 'priority', record, value);
                             }}
@@ -174,6 +174,25 @@ const ChannelsTable = () => {
                 );
             },
         },
+        {
+            title: '权重',
+            dataIndex: 'weight',
+            render: (text, record, index) => {
+                return (
+                    <div>
+                        <InputNumber
+                            style={{width: 70}}
+                            name='weight'
+                            onChange={value => {
+                                manageChannel(record.id, 'weight', record, value);
+                            }}
+                            defaultValue={record.weight}
+                            min={0}
+                        />
+                    </div>
+                );
+            },
+        },
         {
             title: '',
             dataIndex: 'operate',