| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- package model
- import (
- "errors"
- "fmt"
- "one-api/common"
- "strings"
- "sync"
- "github.com/samber/lo"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- )
// Ability is one routing entry: it records that channel ChannelId can serve
// Model for users in Group. (Group, Model, ChannelId) together form the
// primary key. Rows are built from a channel's comma-separated Group and
// Models fields by AddAbilities / UpdateAbilities, which copy the channel's
// enabled state, Priority, Weight and Tag onto every row.
type Ability struct {
	Group     string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
	Model     string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
	// Enabled is set from the channel status when the row is built and is
	// toggled in bulk by UpdateAbilityStatus / UpdateAbilityStatusByTag.
	Enabled bool `json:"enabled"`
	// Priority selects the tier: the highest priority is tried first and later
	// retries step down to lower priorities (see getPriority).
	// NOTE(review): GORM's column-type tag is normally written "type:bigint";
	// the bare "bigint" key looks like it has no effect. Confirm before
	// changing it, since a different column type may trigger a migration.
	Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
	// Weight biases the random pick among abilities of the same priority;
	// GetRandomSatisfiedChannel gives each row a share of Weight+10.
	Weight uint `json:"weight" gorm:"default:0;index"`
	// Tag is copied from the owning channel and is used to update rows in bulk
	// by tag (UpdateAbilityStatusByTag, UpdateAbilityByTag).
	Tag *string `json:"tag" gorm:"index"`
}
// AbilityWithChannel is an Ability joined with the type of the channel it
// belongs to, as produced by GetAllEnableAbilityWithChannels (which selects
// channels.type AS channel_type).
type AbilityWithChannel struct {
	Ability
	ChannelType int `json:"channel_type"`
}
- func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
- var abilities []AbilityWithChannel
- err := DB.Table("abilities").
- Select("abilities.*, channels.type as channel_type").
- Joins("left join channels on abilities.channel_id = channels.id").
- Where("abilities.enabled = ?", true).
- Scan(&abilities).Error
- return abilities, err
- }
- func GetGroupEnabledModels(group string) []string {
- var models []string
- // Find distinct models
- DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
- return models
- }
- func GetEnabledModels() []string {
- var models []string
- // Find distinct models
- DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
- return models
- }
- func GetAllEnableAbilities() []Ability {
- var abilities []Ability
- DB.Find(&abilities, "enabled = ?", true)
- return abilities
- }
- func getPriority(group string, model string, retry int) (int, error) {
- var priorities []int
- err := DB.Model(&Ability{}).
- Select("DISTINCT(priority)").
- Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
- Order("priority DESC"). // 按优先级降序排序
- Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
- if err != nil {
- // 处理错误
- return 0, err
- }
- if len(priorities) == 0 {
- // 如果没有查询到优先级,则返回错误
- return 0, errors.New("数据库一致性被破坏")
- }
- // 确定要使用的优先级
- var priorityToUse int
- if retry >= len(priorities) {
- // 如果重试次数大于优先级数,则使用最小的优先级
- priorityToUse = priorities[len(priorities)-1]
- } else {
- priorityToUse = priorities[retry]
- }
- return priorityToUse, nil
- }
- func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
- maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
- channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
- if retry != 0 {
- priority, err := getPriority(group, model, retry)
- if err != nil {
- return nil, err
- } else {
- channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
- }
- }
- return channelQuery, nil
- }
- func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
- var abilities []Ability
- var err error = nil
- channelQuery, err := getChannelQuery(group, model, retry)
- if err != nil {
- return nil, err
- }
- if common.UsingSQLite || common.UsingPostgreSQL {
- err = channelQuery.Order("weight DESC").Find(&abilities).Error
- } else {
- err = channelQuery.Order("weight DESC").Find(&abilities).Error
- }
- if err != nil {
- return nil, err
- }
- channel := Channel{}
- if len(abilities) > 0 {
- // Randomly choose one
- weightSum := uint(0)
- for _, ability_ := range abilities {
- weightSum += ability_.Weight + 10
- }
- // Randomly choose one
- weight := common.GetRandomInt(int(weightSum))
- for _, ability_ := range abilities {
- weight -= int(ability_.Weight) + 10
- //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
- if weight <= 0 {
- channel.Id = ability_.ChannelId
- break
- }
- }
- } else {
- return nil, errors.New("channel not found")
- }
- err = DB.First(&channel, "id = ?", channel.Id).Error
- return &channel, err
- }
- func (channel *Channel) AddAbilities() error {
- models_ := strings.Split(channel.Models, ",")
- groups_ := strings.Split(channel.Group, ",")
- abilitySet := make(map[string]struct{})
- abilities := make([]Ability, 0, len(models_))
- for _, model := range models_ {
- for _, group := range groups_ {
- key := group + "|" + model
- if _, exists := abilitySet[key]; exists {
- continue
- }
- abilitySet[key] = struct{}{}
- ability := Ability{
- Group: group,
- Model: model,
- ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
- Priority: channel.Priority,
- Weight: uint(channel.GetWeight()),
- Tag: channel.Tag,
- }
- abilities = append(abilities, ability)
- }
- }
- if len(abilities) == 0 {
- return nil
- }
- for _, chunk := range lo.Chunk(abilities, 50) {
- err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (channel *Channel) DeleteAbilities() error {
- return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
- }
- // UpdateAbilities updates abilities of this channel.
- // Make sure the channel is completed before calling this function.
- func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
- isNewTx := false
- // 如果没有传入事务,创建新的事务
- if tx == nil {
- tx = DB.Begin()
- if tx.Error != nil {
- return tx.Error
- }
- isNewTx = true
- defer func() {
- if r := recover(); r != nil {
- tx.Rollback()
- }
- }()
- }
- // First delete all abilities of this channel
- err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
- if err != nil {
- if isNewTx {
- tx.Rollback()
- }
- return err
- }
- // Then add new abilities
- models_ := strings.Split(channel.Models, ",")
- groups_ := strings.Split(channel.Group, ",")
- abilitySet := make(map[string]struct{})
- abilities := make([]Ability, 0, len(models_))
- for _, model := range models_ {
- for _, group := range groups_ {
- key := group + "|" + model
- if _, exists := abilitySet[key]; exists {
- continue
- }
- abilitySet[key] = struct{}{}
- ability := Ability{
- Group: group,
- Model: model,
- ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
- Priority: channel.Priority,
- Weight: uint(channel.GetWeight()),
- Tag: channel.Tag,
- }
- abilities = append(abilities, ability)
- }
- }
- if len(abilities) > 0 {
- for _, chunk := range lo.Chunk(abilities, 50) {
- err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
- if err != nil {
- if isNewTx {
- tx.Rollback()
- }
- return err
- }
- }
- }
- // 如果是新创建的事务,需要提交
- if isNewTx {
- return tx.Commit().Error
- }
- return nil
- }
- func UpdateAbilityStatus(channelId int, status bool) error {
- return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
- }
- func UpdateAbilityStatusByTag(tag string, status bool) error {
- return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
- }
- func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
- ability := Ability{}
- if newTag != nil {
- ability.Tag = newTag
- }
- if priority != nil {
- ability.Priority = priority
- }
- if weight != nil {
- ability.Weight = *weight
- }
- return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
- }
- var fixLock = sync.Mutex{}
- func FixAbility() (int, int, error) {
- lock := fixLock.TryLock()
- if !lock {
- return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
- }
- defer fixLock.Unlock()
- var channels []*Channel
- // Find all channels
- err := DB.Model(&Channel{}).Find(&channels).Error
- if err != nil {
- return 0, 0, err
- }
- if len(channels) == 0 {
- return 0, 0, nil
- }
- successCount := 0
- failCount := 0
- for _, chunk := range lo.Chunk(channels, 50) {
- ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
- // Delete all abilities of this channel
- err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
- if err != nil {
- common.SysError(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
- failCount += len(chunk)
- continue
- }
- // Then add new abilities
- for _, channel := range chunk {
- err = channel.AddAbilities()
- if err != nil {
- common.SysError(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
- failCount++
- } else {
- successCount++
- }
- }
- }
- InitChannelCache()
- return successCount, failCount, nil
- }
|