cache.go 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. package model
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "math/rand"
  7. "one-api/common"
  8. "sort"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. var (
  15. TokenCacheSeconds = common.SyncFrequency
  16. UserId2GroupCacheSeconds = common.SyncFrequency
  17. UserId2QuotaCacheSeconds = common.SyncFrequency
  18. UserId2StatusCacheSeconds = common.SyncFrequency
  19. )
  // token2UserId records, for every token key written to Redis by
  // cacheSetToken, the id of the owning user. It exists only for the periodic
  // cache sync (SyncTokenCache), which drains it on each pass. All access is
  // guarded by token2UserIdLock.
  var (
  	token2UserId     = make(map[string]int)
  	token2UserIdLock sync.RWMutex
  )
  23. func cacheSetToken(token *Token) error {
  24. if !common.RedisEnabled {
  25. return token.SelectUpdate()
  26. }
  27. jsonBytes, err := json.Marshal(token)
  28. if err != nil {
  29. return err
  30. }
  31. err = common.RedisSet(fmt.Sprintf("token:%s", token.Key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
  32. if err != nil {
  33. common.SysError(fmt.Sprintf("failed to set token %s to redis: %s", token.Key, err.Error()))
  34. return err
  35. }
  36. token2UserIdLock.Lock()
  37. defer token2UserIdLock.Unlock()
  38. token2UserId[token.Key] = token.UserId
  39. return nil
  40. }
  41. // CacheGetTokenByKey 从缓存中获取 token 并续期时间,如果缓存中不存在,则从数据库中获取
  42. func CacheGetTokenByKey(key string) (*Token, error) {
  43. if !common.RedisEnabled {
  44. return GetTokenByKey(key)
  45. }
  46. var token *Token
  47. tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  48. if err != nil {
  49. // 如果缓存中不存在,则从数据库中获取
  50. token, err = GetTokenByKey(key)
  51. if err != nil {
  52. return nil, err
  53. }
  54. err = cacheSetToken(token)
  55. return token, nil
  56. }
  57. // 如果缓存中存在,则续期时间
  58. err = common.RedisExpire(fmt.Sprintf("token:%s", key), time.Duration(TokenCacheSeconds)*time.Second)
  59. err = json.Unmarshal([]byte(tokenObjectString), &token)
  60. return token, err
  61. }
  62. func SyncTokenCache(frequency int) {
  63. for {
  64. time.Sleep(time.Duration(frequency) * time.Second)
  65. common.SysLog("syncing tokens from database")
  66. token2UserIdLock.Lock()
  67. // 从token2UserId中获取所有的key
  68. var copyToken2UserId = make(map[string]int)
  69. for s, i := range token2UserId {
  70. copyToken2UserId[s] = i
  71. }
  72. token2UserId = make(map[string]int)
  73. token2UserIdLock.Unlock()
  74. for key := range copyToken2UserId {
  75. token, err := GetTokenByKey(key)
  76. if err != nil {
  77. // 如果数据库中不存在,则删除缓存
  78. common.SysError(fmt.Sprintf("failed to get token %s from database: %s", key, err.Error()))
  79. //delete redis
  80. err := common.RedisDel(fmt.Sprintf("token:%s", key))
  81. if err != nil {
  82. common.SysError(fmt.Sprintf("failed to delete token %s from redis: %s", key, err.Error()))
  83. }
  84. } else {
  85. // 如果数据库中存在,先检查redis
  86. _, err := common.RedisGet(fmt.Sprintf("token:%s", key))
  87. if err != nil {
  88. // 如果redis中不存在,则跳过
  89. continue
  90. }
  91. err = cacheSetToken(token)
  92. if err != nil {
  93. common.SysError(fmt.Sprintf("failed to update token %s to redis: %s", key, err.Error()))
  94. }
  95. }
  96. }
  97. }
  98. }
  99. func CacheGetUserGroup(id int) (group string, err error) {
  100. if !common.RedisEnabled {
  101. return GetUserGroup(id)
  102. }
  103. group, err = common.RedisGet(fmt.Sprintf("user_group:%d", id))
  104. if err != nil {
  105. group, err = GetUserGroup(id)
  106. if err != nil {
  107. return "", err
  108. }
  109. err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  110. if err != nil {
  111. common.SysError("Redis set user group error: " + err.Error())
  112. }
  113. }
  114. return group, err
  115. }
  116. func CacheGetUsername(id int) (username string, err error) {
  117. if !common.RedisEnabled {
  118. return GetUsernameById(id)
  119. }
  120. username, err = common.RedisGet(fmt.Sprintf("user_name:%d", id))
  121. if err != nil {
  122. username, err = GetUsernameById(id)
  123. if err != nil {
  124. return "", err
  125. }
  126. err = common.RedisSet(fmt.Sprintf("user_name:%d", id), username, time.Duration(UserId2GroupCacheSeconds)*time.Second)
  127. if err != nil {
  128. common.SysError("Redis set user group error: " + err.Error())
  129. }
  130. }
  131. return username, err
  132. }
  133. func CacheGetUserQuota(id int) (quota int, err error) {
  134. if !common.RedisEnabled {
  135. return GetUserQuota(id)
  136. }
  137. quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
  138. if err != nil {
  139. quota, err = GetUserQuota(id)
  140. if err != nil {
  141. return 0, err
  142. }
  143. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  144. if err != nil {
  145. common.SysError("Redis set user quota error: " + err.Error())
  146. }
  147. return quota, err
  148. }
  149. quota, err = strconv.Atoi(quotaString)
  150. return quota, err
  151. }
  152. func CacheUpdateUserQuota(id int) error {
  153. if !common.RedisEnabled {
  154. return nil
  155. }
  156. quota, err := GetUserQuota(id)
  157. if err != nil {
  158. return err
  159. }
  160. err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
  161. return err
  162. }
  163. func CacheDecreaseUserQuota(id int, quota int) error {
  164. if !common.RedisEnabled {
  165. return nil
  166. }
  167. err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
  168. return err
  169. }
  170. func CacheIsUserEnabled(userId int) (bool, error) {
  171. if !common.RedisEnabled {
  172. return IsUserEnabled(userId)
  173. }
  174. enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
  175. if err == nil {
  176. return enabled == "1", nil
  177. }
  178. userEnabled, err := IsUserEnabled(userId)
  179. if err != nil {
  180. return false, err
  181. }
  182. enabled = "0"
  183. if userEnabled {
  184. enabled = "1"
  185. }
  186. err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
  187. if err != nil {
  188. common.SysError("Redis set user enabled error: " + err.Error())
  189. }
  190. return userEnabled, err
  191. }
  192. var group2model2channels map[string]map[string][]*Channel
  193. var channelsIDM map[int]*Channel
  194. var channelSyncLock sync.RWMutex
  195. func InitChannelCache() {
  196. newChannelId2channel := make(map[int]*Channel)
  197. var channels []*Channel
  198. DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
  199. for _, channel := range channels {
  200. newChannelId2channel[channel.Id] = channel
  201. }
  202. var abilities []*Ability
  203. DB.Find(&abilities)
  204. groups := make(map[string]bool)
  205. for _, ability := range abilities {
  206. groups[ability.Group] = true
  207. }
  208. newGroup2model2channels := make(map[string]map[string][]*Channel)
  209. newChannelsIDM := make(map[int]*Channel)
  210. for group := range groups {
  211. newGroup2model2channels[group] = make(map[string][]*Channel)
  212. }
  213. for _, channel := range channels {
  214. newChannelsIDM[channel.Id] = channel
  215. groups := strings.Split(channel.Group, ",")
  216. for _, group := range groups {
  217. models := strings.Split(channel.Models, ",")
  218. for _, model := range models {
  219. if _, ok := newGroup2model2channels[group][model]; !ok {
  220. newGroup2model2channels[group][model] = make([]*Channel, 0)
  221. }
  222. newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel)
  223. }
  224. }
  225. }
  226. // sort by priority
  227. for group, model2channels := range newGroup2model2channels {
  228. for model, channels := range model2channels {
  229. sort.Slice(channels, func(i, j int) bool {
  230. return channels[i].GetPriority() > channels[j].GetPriority()
  231. })
  232. newGroup2model2channels[group][model] = channels
  233. }
  234. }
  235. channelSyncLock.Lock()
  236. group2model2channels = newGroup2model2channels
  237. channelsIDM = newChannelsIDM
  238. channelSyncLock.Unlock()
  239. common.SysLog("channels synced from database")
  240. }
  241. func SyncChannelCache(frequency int) {
  242. for {
  243. time.Sleep(time.Duration(frequency) * time.Second)
  244. common.SysLog("syncing channels from database")
  245. InitChannelCache()
  246. }
  247. }
  248. func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
  249. if strings.HasPrefix(model, "gpt-4-gizmo") {
  250. model = "gpt-4-gizmo-*"
  251. }
  252. // if memory cache is disabled, get channel directly from database
  253. if !common.MemoryCacheEnabled {
  254. return GetRandomSatisfiedChannel(group, model)
  255. }
  256. channelSyncLock.RLock()
  257. defer channelSyncLock.RUnlock()
  258. channels := group2model2channels[group][model]
  259. if len(channels) == 0 {
  260. return nil, errors.New("channel not found")
  261. }
  262. endIdx := len(channels)
  263. // choose by priority
  264. firstChannel := channels[0]
  265. if firstChannel.GetPriority() > 0 {
  266. for i := range channels {
  267. if channels[i].GetPriority() != firstChannel.GetPriority() {
  268. endIdx = i
  269. break
  270. }
  271. }
  272. }
  273. // 平滑系数
  274. smoothingFactor := 10
  275. // Calculate the total weight of all channels up to endIdx
  276. totalWeight := 0
  277. for _, channel := range channels[:endIdx] {
  278. totalWeight += channel.GetWeight() + smoothingFactor
  279. }
  280. //if totalWeight == 0 {
  281. // // If all weights are 0, select a channel randomly
  282. // return channels[rand.Intn(endIdx)], nil
  283. //}
  284. // Generate a random value in the range [0, totalWeight)
  285. randomWeight := rand.Intn(totalWeight)
  286. // Find a channel based on its weight
  287. for _, channel := range channels[:endIdx] {
  288. randomWeight -= channel.GetWeight() + smoothingFactor
  289. if randomWeight < 0 {
  290. return channel, nil
  291. }
  292. }
  293. // return null if no channel is not found
  294. return nil, errors.New("channel not found")
  295. }
  296. func CacheGetChannel(id int) (*Channel, error) {
  297. if !common.MemoryCacheEnabled {
  298. return GetChannelById(id, true)
  299. }
  300. channelSyncLock.RLock()
  301. defer channelSyncLock.RUnlock()
  302. c, ok := channelsIDM[id]
  303. if !ok {
  304. return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id))
  305. }
  306. return c, nil
  307. }