cache.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math/rand"
  7. "one-api/common"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. TokenCacheSeconds = common.SyncFrequency
  16. UserId2GroupCacheSeconds = common.SyncFrequency
  17. UserId2QuotaCacheSeconds = common.SyncFrequency
  18. UserId2StatusCacheSeconds = common.SyncFrequency
  19. )
  20. func CacheGetTokenByKey(key string) (*Token, error) {
  21. keyCol := "`key`"
  22. if common.UsingPostgreSQL {
  23. keyCol = `"key"`
  24. }
  25. var token Token
  26. if !common.RedisEnabled {
  27. err := DB.Where(keyCol+" = ?", key).First(&token).Error
  28. return &token, err
  29. }
  30. tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  31. if err != nil {
  32. err := DB.Where(keyCol+" = ?", key).First(&token).Error
  33. if err != nil {
  34. return nil, err
  35. }
  36. jsonBytes, err := json.Marshal(token)
  37. if err != nil {
  38. return nil, err
  39. }
  40. err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
  41. if err != nil {
  42. common.SysError("Redis set token error: " + err.Error())
  43. }
  44. return &token, nil
  45. }
  46. err = json.Unmarshal([]byte(tokenObjectString), &token)
  47. return &token, err
  48. }
  49. func CacheGetUserGroup(id int) (group string, err error) {
  50. if !common.RedisEnabled {
  51. return GetUserGroup(id)
  52. }
  53. group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  54. if err != nil {
  55. group, err = GetUserGroup(id)
  56. if err != nil {
  57. return "", err
  58. }
  59. err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  60. if err != nil {
  61. common.SysError("Redis set user group error: " + err.Error())
  62. }
  63. }
  64. return group, err
  65. }
  66. func CacheGetUsername(id int) (username string, err error) {
  67. if !common.RedisEnabled {
  68. return GetUsernameById(id)
  69. }
  70. username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  71. if err != nil {
  72. username, err = GetUsernameById(id)
  73. if err != nil {
  74. return "", err
  75. }
  76. err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  77. if err != nil {
  78. common.SysError("Redis set user group error: " + err.Error())
  79. }
  80. }
  81. return username, err
  82. }
  83. func CacheGetUserQuota(id int) (quota int, err error) {
  84. if !common.RedisEnabled {
  85. return GetUserQuota(id)
  86. }
  87. quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  88. if err != nil {
  89. quota, err = GetUserQuota(id)
  90. if err != nil {
  91. return 0, err
  92. }
  93. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  94. if err != nil {
  95. common.SysError("Redis set user quota error: " + err.Error())
  96. }
  97. return quota, err
  98. }
  99. quota, err = strconv.Atoi(quotaString)
  100. return quota, err
  101. }
  102. func CacheUpdateUserQuota(id int) error {
  103. if !common.RedisEnabled {
  104. return nil
  105. }
  106. quota, err := GetUserQuota(id)
  107. if err != nil {
  108. return err
  109. }
  110. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  111. return err
  112. }
  113. func CacheDecreaseUserQuota(id int, quota int) error {
  114. if !common.RedisEnabled {
  115. return nil
  116. }
  117. err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  118. return err
  119. }
  120. func CacheIsUserEnabled(userId int) (bool, error) {
  121. if !common.RedisEnabled {
  122. return IsUserEnabled(userId)
  123. }
  124. enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  125. if err == nil {
  126. return enabled == "1", nil
  127. }
  128. userEnabled, err := IsUserEnabled(userId)
  129. if err != nil {
  130. return false, err
  131. }
  132. enabled = "0"
  133. if userEnabled {
  134. enabled = "1"
  135. }
  136. err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
  137. if err != nil {
  138. common.SysError("Redis set user enabled error: " + err.Error())
  139. }
  140. return userEnabled, err
  141. }
  142. var group2model2channels map[string]map[string][]*Channel
  143. var channelsIDM map[int]*Channel
  144. var channelSyncLock sync.RWMutex
  145. func InitChannelCache() {
  146. newChannelId2channel := make(map[int]*Channel)
  147. var channels []*Channel
  148. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  149. for _, channel := range channels {
  150. newChannelId2channel[channel.Id] = channel
  151. }
  152. var abilities []*Ability
  153. DB.Find(&abilities)
  154. groups := make(map[string]bool)
  155. for _, ability := range abilities {
  156. groups[ability.Group] = true
  157. }
  158. newGroup2model2channels := make(map[string]map[string][]*Channel)
  159. newChannelsIDM := make(map[int]*Channel)
  160. for group := range groups {
  161. newGroup2model2channels[group] = make(map[string][]*Channel)
  162. }
  163. for _, channel := range channels {
  164. newChannelsIDM[channel.Id] = channel
  165. groups := strings.Split(channel.Group, ",")
  166. for _, group := range groups {
  167. models := strings.Split(channel.Models, ",")
  168. for _, model := range models {
  169. if _, ok := newGroup2model2channels[group][model]; !ok {
  170. newGroup2model2channels[group][model] = make([]*Channel, 0)
  171. }
  172. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  173. }
  174. }
  175. }
  176. // sort by priority
  177. for group, model2channels := range newGroup2model2channels {
  178. for model, channels := range model2channels {
  179. sort.Slice(channels, func(i, j int) bool {
  180. return channels[i].GetPriority() > channels[j].GetPriority()
  181. })
  182. newGroup2model2channels[group][model] = channels
  183. }
  184. }
  185. channelSyncLock.Lock()
  186. group2model2channels = newGroup2model2channels
  187. channelsIDM = newChannelsIDM
  188. channelSyncLock.Unlock()
  189. common.SysLog("channels synced from database")
  190. }
  191. func SyncChannelCache(frequency int) {
  192. for {
  193. time.Sleep(time.Duration(frequency) * time.Second)
  194. common.SysLog("syncing channels from database")
  195. InitChannelCache()
  196. }
  197. }
  198. func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
  199. if strings.HasPrefix(model, "gpt-4-gizmo") {
  200. model = "gpt-4-gizmo-*"
  201. }
  202. // if memory cache is disabled, get channel directly from database
  203. if !common.MemoryCacheEnabled {
  204. return GetRandomSatisfiedChannel(group, model)
  205. }
  206. channelSyncLock.RLock()
  207. defer channelSyncLock.RUnlock()
  208. channels := group2model2channels[group][model]
  209. if len(channels) == 0 {
  210. return nil, errors.New("channel not found")
  211. }
  212. endIdx := len(channels)
  213. // choose by priority
  214. firstChannel := channels[0]
  215. if firstChannel.GetPriority() > 0 {
  216. for i := range channels {
  217. if channels[i].GetPriority() != firstChannel.GetPriority() {
  218. endIdx = i
  219. break
  220. }
  221. }
  222. }
  223. // Calculate the total weight of all channels up to endIdx
  224. totalWeight := 0
  225. for _, channel := range channels[:endIdx] {
  226. totalWeight += channel.GetWeight()
  227. }
  228. if totalWeight == 0 {
  229. // If all weights are 0, select a channel randomly
  230. return channels[rand.Intn(endIdx)], nil
  231. }
  232. // Generate a random value in the range [0, totalWeight)
  233. randomWeight := rand.Intn(totalWeight)
  234. // Find a channel based on its weight
  235. for _, channel := range channels[:endIdx] {
  236. randomWeight -= channel.GetWeight()
  237. if randomWeight <= 0 {
  238. return channel, nil
  239. }
  240. }
  241. // return null if no channel is not found
  242. return nil, errors.New("channel not found")
  243. }
  244. func CacheGetChannel(id int) (*Channel, error) {
  245. if !common.MemoryCacheEnabled {
  246. return GetChannelById(id, true)
  247. }
  248. channelSyncLock.RLock()
  249. defer channelSyncLock.RUnlock()
  250. c, ok := channelsIDM[id]
  251. if !ok {
  252. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  253. }
  254. return c, nil
  255. }