cache.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package model
  2. import (
  3. "errors"
  4. "fmt"
  5. "math/rand"
  6. "one-api/common"
  7. "sort"
  8. "strings"
  9. "sync"
  10. "time"
  11. )
  12. var group2model2channels map[string]map[string][]*Channel
  13. var channelsIDM map[int]*Channel
  14. var channelSyncLock sync.RWMutex
  15. func InitChannelCache() {
  16. if !common.MemoryCacheEnabled {
  17. return
  18. }
  19. newChannelId2channel := make(map[int]*Channel)
  20. var channels []*Channel
  21. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  22. for _, channel := range channels {
  23. newChannelId2channel[channel.Id] = channel
  24. }
  25. var abilities []*Ability
  26. DB.Find(&abilities)
  27. groups := make(map[string]bool)
  28. for _, ability := range abilities {
  29. groups[ability.Group] = true
  30. }
  31. newGroup2model2channels := make(map[string]map[string][]*Channel)
  32. newChannelsIDM := make(map[int]*Channel)
  33. for group := range groups {
  34. newGroup2model2channels[group] = make(map[string][]*Channel)
  35. }
  36. for _, channel := range channels {
  37. newChannelsIDM[channel.Id] = channel
  38. groups := strings.Split(channel.Group, ",")
  39. for _, group := range groups {
  40. models := strings.Split(channel.Models, ",")
  41. for _, model := range models {
  42. if _, ok := newGroup2model2channels[group][model]; !ok {
  43. newGroup2model2channels[group][model] = make([]*Channel, 0)
  44. }
  45. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  46. }
  47. }
  48. }
  49. // sort by priority
  50. for group, model2channels := range newGroup2model2channels {
  51. for model, channels := range model2channels {
  52. sort.Slice(channels, func(i, j int) bool {
  53. return channels[i].GetPriority() > channels[j].GetPriority()
  54. })
  55. newGroup2model2channels[group][model] = channels
  56. }
  57. }
  58. channelSyncLock.Lock()
  59. group2model2channels = newGroup2model2channels
  60. channelsIDM = newChannelsIDM
  61. channelSyncLock.Unlock()
  62. common.SysLog("channels synced from database")
  63. }
  64. func SyncChannelCache(frequency int) {
  65. for {
  66. time.Sleep(time.Duration(frequency) * time.Second)
  67. common.SysLog("syncing channels from database")
  68. InitChannelCache()
  69. }
  70. }
  71. func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
  72. if strings.HasPrefix(model, "gpt-4-gizmo") {
  73. model = "gpt-4-gizmo-*"
  74. }
  75. if strings.HasPrefix(model, "gpt-4o-gizmo") {
  76. model = "gpt-4o-gizmo-*"
  77. }
  78. // if memory cache is disabled, get channel directly from database
  79. if !common.MemoryCacheEnabled {
  80. return GetRandomSatisfiedChannel(group, model, retry)
  81. }
  82. channelSyncLock.RLock()
  83. channels := group2model2channels[group][model]
  84. channelSyncLock.RUnlock()
  85. if len(channels) == 0 {
  86. return nil, errors.New("channel not found")
  87. }
  88. uniquePriorities := make(map[int]bool)
  89. for _, channel := range channels {
  90. uniquePriorities[int(channel.GetPriority())] = true
  91. }
  92. var sortedUniquePriorities []int
  93. for priority := range uniquePriorities {
  94. sortedUniquePriorities = append(sortedUniquePriorities, priority)
  95. }
  96. sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities)))
  97. if retry >= len(uniquePriorities) {
  98. retry = len(uniquePriorities) - 1
  99. }
  100. targetPriority := int64(sortedUniquePriorities[retry])
  101. // get the priority for the given retry number
  102. var targetChannels []*Channel
  103. for _, channel := range channels {
  104. if channel.GetPriority() == targetPriority {
  105. targetChannels = append(targetChannels, channel)
  106. }
  107. }
  108. // 平滑系数
  109. smoothingFactor := 10
  110. // Calculate the total weight of all channels up to endIdx
  111. totalWeight := 0
  112. for _, channel := range targetChannels {
  113. totalWeight += channel.GetWeight() + smoothingFactor
  114. }
  115. // Generate a random value in the range [0, totalWeight)
  116. randomWeight := rand.Intn(totalWeight)
  117. // Find a channel based on its weight
  118. for _, channel := range targetChannels {
  119. randomWeight -= channel.GetWeight() + smoothingFactor
  120. if randomWeight < 0 {
  121. return channel, nil
  122. }
  123. }
  124. // return null if no channel is not found
  125. return nil, errors.New("channel not found")
  126. }
  127. func CacheGetChannel(id int) (*Channel, error) {
  128. if !common.MemoryCacheEnabled {
  129. return GetChannelById(id, true)
  130. }
  131. channelSyncLock.RLock()
  132. defer channelSyncLock.RUnlock()
  133. c, ok := channelsIDM[id]
  134. if !ok {
  135. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  136. }
  137. return c, nil
  138. }
  139. func CacheUpdateChannelStatus(id int, status int) {
  140. if !common.MemoryCacheEnabled {
  141. return
  142. }
  143. channelSyncLock.Lock()
  144. defer channelSyncLock.Unlock()
  145. if channel, ok := channelsIDM[id]; ok {
  146. channel.Status = status
  147. }
  148. }