ability.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "one-api/common"
  6. "strings"
  7. "github.com/samber/lo"
  8. "gorm.io/gorm"
  9. "gorm.io/gorm/clause"
  10. )
  // Ability maps a (group, model) pair to a channel that can serve it. Group,
  // Model and ChannelId together form the composite primary key, so a channel
  // has at most one row per (group, model).
  type Ability struct {
  	Group     string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
  	Model     string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
  	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
  	// Enabled is derived from the channel's status when the row is written;
  	// channel selection only considers rows with enabled = true.
  	Enabled bool `json:"enabled"`
  	// Priority orders candidate channels: the highest value is tried first and
  	// retries step down through the lower distinct values.
  	// NOTE(review): "bigint" is not written as "type:bigint", so gorm most
  	// likely ignores it and derives the column type from int64 — confirm.
  	Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
  	// Weight biases the random choice among rows sharing the same priority;
  	// the selector uses Weight+10 as each row's effective weight.
  	Weight uint `json:"weight" gorm:"default:0;index"`
  	// Tag is copied from the channel and allows abilities to be updated in
  	// bulk by tag.
  	Tag *string `json:"tag" gorm:"index"`
  }
  20. func GetGroupModels(group string) []string {
  21. var models []string
  22. // Find distinct models
  23. DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
  24. return models
  25. }
  26. func GetEnabledModels() []string {
  27. var models []string
  28. // Find distinct models
  29. DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
  30. return models
  31. }
  32. func GetAllEnableAbilities() []Ability {
  33. var abilities []Ability
  34. DB.Find(&abilities, "enabled = ?", true)
  35. return abilities
  36. }
  37. func getPriority(group string, model string, retry int) (int, error) {
  38. var priorities []int
  39. err := DB.Model(&Ability{}).
  40. Select("DISTINCT(priority)").
  41. Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal).
  42. Order("priority DESC"). // 按优先级降序排序
  43. Pluck("priority", &priorities).Error // Pluck用于将查询的结果直接扫描到一个切片中
  44. if err != nil {
  45. // 处理错误
  46. return 0, err
  47. }
  48. if len(priorities) == 0 {
  49. // 如果没有查询到优先级,则返回错误
  50. return 0, errors.New("数据库一致性被破坏")
  51. }
  52. // 确定要使用的优先级
  53. var priorityToUse int
  54. if retry >= len(priorities) {
  55. // 如果重试次数大于优先级数,则使用最小的优先级
  56. priorityToUse = priorities[len(priorities)-1]
  57. } else {
  58. priorityToUse = priorities[retry]
  59. }
  60. return priorityToUse, nil
  61. }
  62. func getChannelQuery(group string, model string, retry int) *gorm.DB {
  63. maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, commonTrueVal)
  64. channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, commonTrueVal, maxPrioritySubQuery)
  65. if retry != 0 {
  66. priority, err := getPriority(group, model, retry)
  67. if err != nil {
  68. common.SysError(fmt.Sprintf("Get priority failed: %s", err.Error()))
  69. } else {
  70. channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, commonTrueVal, priority)
  71. }
  72. }
  73. return channelQuery
  74. }
  75. func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  76. var abilities []Ability
  77. var err error = nil
  78. channelQuery := getChannelQuery(group, model, retry)
  79. if common.UsingSQLite || common.UsingPostgreSQL {
  80. err = channelQuery.Order("weight DESC").Find(&abilities).Error
  81. } else {
  82. err = channelQuery.Order("weight DESC").Find(&abilities).Error
  83. }
  84. if err != nil {
  85. return nil, err
  86. }
  87. channel := Channel{}
  88. if len(abilities) > 0 {
  89. // Randomly choose one
  90. weightSum := uint(0)
  91. for _, ability_ := range abilities {
  92. weightSum += ability_.Weight + 10
  93. }
  94. // Randomly choose one
  95. weight := common.GetRandomInt(int(weightSum))
  96. for _, ability_ := range abilities {
  97. weight -= int(ability_.Weight) + 10
  98. //log.Printf("weight: %d, ability weight: %d", weight, *ability_.Weight)
  99. if weight <= 0 {
  100. channel.Id = ability_.ChannelId
  101. break
  102. }
  103. }
  104. } else {
  105. return nil, errors.New("channel not found")
  106. }
  107. err = DB.First(&channel, "id = ?", channel.Id).Error
  108. return &channel, err
  109. }
  110. func (channel *Channel) AddAbilities() error {
  111. models_ := strings.Split(channel.Models, ",")
  112. groups_ := strings.Split(channel.Group, ",")
  113. abilitySet := make(map[string]struct{})
  114. abilities := make([]Ability, 0, len(models_))
  115. for _, model := range models_ {
  116. for _, group := range groups_ {
  117. key := group + "|" + model
  118. if _, exists := abilitySet[key]; exists {
  119. continue
  120. }
  121. abilitySet[key] = struct{}{}
  122. ability := Ability{
  123. Group: group,
  124. Model: model,
  125. ChannelId: channel.Id,
  126. Enabled: channel.Status == common.ChannelStatusEnabled,
  127. Priority: channel.Priority,
  128. Weight: uint(channel.GetWeight()),
  129. Tag: channel.Tag,
  130. }
  131. abilities = append(abilities, ability)
  132. }
  133. }
  134. if len(abilities) == 0 {
  135. return nil
  136. }
  137. for _, chunk := range lo.Chunk(abilities, 50) {
  138. err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
  139. if err != nil {
  140. return err
  141. }
  142. }
  143. return nil
  144. }
  145. func (channel *Channel) DeleteAbilities() error {
  146. return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
  147. }
  148. // UpdateAbilities updates abilities of this channel.
  149. // Make sure the channel is completed before calling this function.
  150. func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
  151. isNewTx := false
  152. // 如果没有传入事务,创建新的事务
  153. if tx == nil {
  154. tx = DB.Begin()
  155. if tx.Error != nil {
  156. return tx.Error
  157. }
  158. isNewTx = true
  159. defer func() {
  160. if r := recover(); r != nil {
  161. tx.Rollback()
  162. }
  163. }()
  164. }
  165. // First delete all abilities of this channel
  166. err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
  167. if err != nil {
  168. if isNewTx {
  169. tx.Rollback()
  170. }
  171. return err
  172. }
  173. // Then add new abilities
  174. models_ := strings.Split(channel.Models, ",")
  175. groups_ := strings.Split(channel.Group, ",")
  176. abilitySet := make(map[string]struct{})
  177. abilities := make([]Ability, 0, len(models_))
  178. for _, model := range models_ {
  179. for _, group := range groups_ {
  180. key := group + "|" + model
  181. if _, exists := abilitySet[key]; exists {
  182. continue
  183. }
  184. abilitySet[key] = struct{}{}
  185. ability := Ability{
  186. Group: group,
  187. Model: model,
  188. ChannelId: channel.Id,
  189. Enabled: channel.Status == common.ChannelStatusEnabled,
  190. Priority: channel.Priority,
  191. Weight: uint(channel.GetWeight()),
  192. Tag: channel.Tag,
  193. }
  194. abilities = append(abilities, ability)
  195. }
  196. }
  197. if len(abilities) > 0 {
  198. for _, chunk := range lo.Chunk(abilities, 50) {
  199. err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
  200. if err != nil {
  201. if isNewTx {
  202. tx.Rollback()
  203. }
  204. return err
  205. }
  206. }
  207. }
  208. // 如果是新创建的事务,需要提交
  209. if isNewTx {
  210. return tx.Commit().Error
  211. }
  212. return nil
  213. }
  214. func UpdateAbilityStatus(channelId int, status bool) error {
  215. return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
  216. }
  217. func UpdateAbilityStatusByTag(tag string, status bool) error {
  218. return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
  219. }
  220. func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
  221. ability := Ability{}
  222. if newTag != nil {
  223. ability.Tag = newTag
  224. }
  225. if priority != nil {
  226. ability.Priority = priority
  227. }
  228. if weight != nil {
  229. ability.Weight = *weight
  230. }
  231. return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
  232. }
  233. func FixAbility() (int, error) {
  234. var channelIds []int
  235. count := 0
  236. // Find all channel ids from channel table
  237. err := DB.Model(&Channel{}).Pluck("id", &channelIds).Error
  238. if err != nil {
  239. common.SysError(fmt.Sprintf("Get channel ids from channel table failed: %s", err.Error()))
  240. return 0, err
  241. }
  242. // Delete abilities of channels that are not in channel table - in batches to avoid too many placeholders
  243. if len(channelIds) > 0 {
  244. // Process deletion in chunks to avoid "too many placeholders" error
  245. for _, chunk := range lo.Chunk(channelIds, 100) {
  246. err = DB.Where("channel_id NOT IN (?)", chunk).Delete(&Ability{}).Error
  247. if err != nil {
  248. common.SysError(fmt.Sprintf("Delete abilities of channels (batch) that are not in channel table failed: %s", err.Error()))
  249. return 0, err
  250. }
  251. }
  252. } else {
  253. // If no channels exist, delete all abilities
  254. err = DB.Delete(&Ability{}).Error
  255. if err != nil {
  256. common.SysError(fmt.Sprintf("Delete all abilities failed: %s", err.Error()))
  257. return 0, err
  258. }
  259. common.SysLog("Delete all abilities successfully")
  260. return 0, nil
  261. }
  262. common.SysLog(fmt.Sprintf("Delete abilities of channels that are not in channel table successfully, ids: %v", channelIds))
  263. count += len(channelIds)
  264. // Use channelIds to find channel not in abilities table
  265. var abilityChannelIds []int
  266. err = DB.Table("abilities").Distinct("channel_id").Pluck("channel_id", &abilityChannelIds).Error
  267. if err != nil {
  268. common.SysError(fmt.Sprintf("Get channel ids from abilities table failed: %s", err.Error()))
  269. return count, err
  270. }
  271. var channels []Channel
  272. if len(abilityChannelIds) == 0 {
  273. err = DB.Find(&channels).Error
  274. } else {
  275. // Process query in chunks to avoid "too many placeholders" error
  276. err = nil
  277. for _, chunk := range lo.Chunk(abilityChannelIds, 100) {
  278. var channelsChunk []Channel
  279. err = DB.Where("id NOT IN (?)", chunk).Find(&channelsChunk).Error
  280. if err != nil {
  281. common.SysError(fmt.Sprintf("Find channels not in abilities table failed: %s", err.Error()))
  282. return count, err
  283. }
  284. channels = append(channels, channelsChunk...)
  285. }
  286. }
  287. for _, channel := range channels {
  288. err := channel.UpdateAbilities(nil)
  289. if err != nil {
  290. common.SysError(fmt.Sprintf("Update abilities of channel %d failed: %s", channel.Id, err.Error()))
  291. } else {
  292. common.SysLog(fmt.Sprintf("Update abilities of channel %d successfully", channel.Id))
  293. count++
  294. }
  295. }
  296. InitChannelCache()
  297. return count, nil
  298. }