cache.go 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164
  1. package model
  2. import (
  3. "context"
  4. "encoding"
  5. "errors"
  6. "math/rand/v2"
  7. "slices"
  8. "sort"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/bytedance/sonic"
  13. "github.com/labring/aiproxy/core/common"
  14. "github.com/labring/aiproxy/core/common/config"
  15. "github.com/labring/aiproxy/core/common/conv"
  16. "github.com/labring/aiproxy/core/common/notify"
  17. "github.com/maruel/natural"
  18. "github.com/redis/go-redis/v9"
  19. log "github.com/sirupsen/logrus"
  20. )
  21. const (
  22. SyncFrequency = time.Minute * 3
  23. TokenCacheKey = "token:%s"
  24. GroupCacheKey = "group:%s"
  25. GroupModelTPMKey = "group:%s:model_tpm"
  26. )
  27. var (
  28. _ encoding.BinaryMarshaler = (*redisStringSlice)(nil)
  29. _ redis.Scanner = (*redisStringSlice)(nil)
  30. )
  31. type redisStringSlice []string
  32. func (r *redisStringSlice) ScanRedis(value string) error {
  33. return sonic.Unmarshal(conv.StringToBytes(value), r)
  34. }
  35. func (r redisStringSlice) MarshalBinary() ([]byte, error) {
  36. return sonic.Marshal(r)
  37. }
  38. type redisTime time.Time
  39. var (
  40. _ redis.Scanner = (*redisTime)(nil)
  41. _ encoding.BinaryMarshaler = (*redisTime)(nil)
  42. )
  43. func (t *redisTime) ScanRedis(value string) error {
  44. return (*time.Time)(t).UnmarshalBinary(conv.StringToBytes(value))
  45. }
  46. func (t redisTime) MarshalBinary() ([]byte, error) {
  47. return time.Time(t).MarshalBinary()
  48. }
  49. type TokenCache struct {
  50. Group string `json:"group" redis:"g"`
  51. Key string `json:"-" redis:"-"`
  52. Name string `json:"name" redis:"n"`
  53. Subnets redisStringSlice `json:"subnets" redis:"s"`
  54. Models redisStringSlice `json:"models" redis:"m"`
  55. ID int `json:"id" redis:"i"`
  56. Status int `json:"status" redis:"st"`
  57. UsedAmount float64 `json:"used_amount" redis:"u"`
  58. // Quota system
  59. Quota float64 `json:"quota" redis:"q"`
  60. PeriodQuota float64 `json:"period_quota" redis:"pq"`
  61. PeriodType string `json:"period_type" redis:"pt"`
  62. PeriodLastUpdateTime redisTime `json:"period_last_update_time" redis:"plut"`
  63. PeriodLastUpdateAmount float64 `json:"period_last_update_amount" redis:"plua"`
  64. availableSets []string
  65. modelsBySet map[string][]string
  66. }
  67. func (t *TokenCache) SetAvailableSets(availableSets []string) {
  68. t.availableSets = availableSets
  69. }
  70. func (t *TokenCache) SetModelsBySet(modelsBySet map[string][]string) {
  71. t.modelsBySet = modelsBySet
  72. }
  73. func (t *TokenCache) ContainsModel(model string) bool {
  74. if len(t.Models) != 0 {
  75. if !slices.Contains(t.Models, model) {
  76. return false
  77. }
  78. }
  79. return containsModel(model, t.availableSets, t.modelsBySet)
  80. }
  81. func containsModel(model string, sets []string, modelsBySet map[string][]string) bool {
  82. for _, set := range sets {
  83. if slices.Contains(modelsBySet[set], model) {
  84. return true
  85. }
  86. }
  87. return false
  88. }
  89. func (t *TokenCache) Range(fn func(model string) bool) {
  90. ranged := make(map[string]struct{})
  91. if len(t.Models) != 0 {
  92. for _, model := range t.Models {
  93. if _, ok := ranged[model]; !ok && containsModel(model, t.availableSets, t.modelsBySet) {
  94. if !fn(model) {
  95. return
  96. }
  97. }
  98. ranged[model] = struct{}{}
  99. }
  100. return
  101. }
  102. for _, set := range t.availableSets {
  103. for _, model := range t.modelsBySet[set] {
  104. if _, ok := ranged[model]; !ok {
  105. if !fn(model) {
  106. return
  107. }
  108. }
  109. ranged[model] = struct{}{}
  110. }
  111. }
  112. }
  113. func (t *Token) ToTokenCache() *TokenCache {
  114. return &TokenCache{
  115. ID: t.ID,
  116. Group: t.GroupID,
  117. Key: t.Key,
  118. Name: string(t.Name),
  119. Models: t.Models,
  120. Subnets: t.Subnets,
  121. Status: t.Status,
  122. UsedAmount: t.UsedAmount,
  123. Quota: t.Quota,
  124. PeriodQuota: t.PeriodQuota,
  125. PeriodType: string(t.PeriodType),
  126. PeriodLastUpdateTime: redisTime(t.PeriodLastUpdateTime),
  127. PeriodLastUpdateAmount: t.PeriodLastUpdateAmount,
  128. }
  129. }
  130. func CacheDeleteToken(key string) error {
  131. if !common.RedisEnabled {
  132. return nil
  133. }
  134. return common.RDB.Del(context.Background(), common.RedisKeyf(TokenCacheKey, key)).Err()
  135. }
  136. func CacheSetToken(token *TokenCache) error {
  137. if !common.RedisEnabled {
  138. return nil
  139. }
  140. key := common.RedisKeyf(TokenCacheKey, token.Key)
  141. pipe := common.RDB.Pipeline()
  142. pipe.HSet(context.Background(), key, token)
  143. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  144. pipe.Expire(context.Background(), key, expireTime)
  145. _, err := pipe.Exec(context.Background())
  146. return err
  147. }
  148. func CacheGetTokenByKey(key string) (*TokenCache, error) {
  149. if !common.RedisEnabled {
  150. token, err := GetTokenByKey(key)
  151. if err != nil {
  152. return nil, err
  153. }
  154. return token.ToTokenCache(), nil
  155. }
  156. cacheKey := common.RedisKeyf(TokenCacheKey, key)
  157. tokenCache := &TokenCache{}
  158. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(tokenCache)
  159. if err == nil && tokenCache.ID != 0 {
  160. tokenCache.Key = key
  161. return tokenCache, nil
  162. } else if err != nil && !errors.Is(err, redis.Nil) {
  163. log.Errorf("get token (%s) from redis error: %s", key, err.Error())
  164. }
  165. token, err := GetTokenByKey(key)
  166. if err != nil {
  167. return nil, err
  168. }
  169. tc := token.ToTokenCache()
  170. if err := CacheSetToken(tc); err != nil {
  171. log.Error("redis set token error: " + err.Error())
  172. }
  173. return tc, nil
  174. }
  175. var updateTokenUsedAmountOnlyIncreaseScript = redis.NewScript(`
  176. local used_amount = redis.call("HGet", KEYS[1], "ua")
  177. if used_amount == false then
  178. return redis.status_reply("ok")
  179. end
  180. if ARGV[1] < used_amount then
  181. return redis.status_reply("ok")
  182. end
  183. redis.call("HSet", KEYS[1], "ua", ARGV[1])
  184. return redis.status_reply("ok")
  185. `)
  186. func CacheUpdateTokenUsedAmountOnlyIncrease(key string, amount float64) error {
  187. if !common.RedisEnabled {
  188. return nil
  189. }
  190. return updateTokenUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, amount).
  191. Err()
  192. }
  193. // CacheResetTokenPeriodUsage resets period usage in cache and updates period last update info
  194. func CacheResetTokenPeriodUsage(
  195. key string,
  196. periodLastUpdateTime time.Time,
  197. periodLastUpdateAmount float64,
  198. ) error {
  199. if !common.RedisEnabled {
  200. return nil
  201. }
  202. cacheKey := common.RedisKeyf(TokenCacheKey, key)
  203. pipe := common.RDB.Pipeline()
  204. periodLastUpdateTimeBytes, _ := periodLastUpdateTime.MarshalBinary()
  205. pipe.HSet(context.Background(), cacheKey, "plut", periodLastUpdateTimeBytes)
  206. pipe.HSet(context.Background(), cacheKey, "plua", periodLastUpdateAmount)
  207. _, err := pipe.Exec(context.Background())
  208. return err
  209. }
  210. var updateTokenNameScript = redis.NewScript(`
  211. if redis.call("HExists", KEYS[1], "n") then
  212. redis.call("HSet", KEYS[1], "n", ARGV[1])
  213. end
  214. return redis.status_reply("ok")
  215. `)
  216. func CacheUpdateTokenName(key, name string) error {
  217. if !common.RedisEnabled {
  218. return nil
  219. }
  220. return updateTokenNameScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, name).
  221. Err()
  222. }
  223. var updateTokenStatusScript = redis.NewScript(`
  224. if redis.call("HExists", KEYS[1], "st") then
  225. redis.call("HSet", KEYS[1], "st", ARGV[1])
  226. end
  227. return redis.status_reply("ok")
  228. `)
  229. func CacheUpdateTokenStatus(key string, status int) error {
  230. if !common.RedisEnabled {
  231. return nil
  232. }
  233. return updateTokenStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, status).
  234. Err()
  235. }
  236. type redisMap[K comparable, V any] map[K]V
  237. var (
  238. _ redis.Scanner = (*redisMap[string, any])(nil)
  239. _ encoding.BinaryMarshaler = (*redisMap[string, any])(nil)
  240. )
  241. func (r *redisMap[K, V]) ScanRedis(value string) error {
  242. return sonic.UnmarshalString(value, r)
  243. }
  244. func (r redisMap[K, V]) MarshalBinary() ([]byte, error) {
  245. return sonic.Marshal(r)
  246. }
  247. type (
  248. redisGroupModelConfigMap = redisMap[string, GroupModelConfig]
  249. )
  250. type GroupCache struct {
  251. ID string `json:"-" redis:"-"`
  252. Status int `json:"status" redis:"st"`
  253. UsedAmount float64 `json:"used_amount" redis:"ua"`
  254. RPMRatio float64 `json:"rpm_ratio" redis:"rpm_r"`
  255. TPMRatio float64 `json:"tpm_ratio" redis:"tpm_r"`
  256. AvailableSets redisStringSlice `json:"available_sets" redis:"ass"`
  257. ModelConfigs redisGroupModelConfigMap `json:"model_configs" redis:"mc"`
  258. BalanceAlertEnabled bool `json:"balance_alert_enabled" redis:"bae"`
  259. BalanceAlertThreshold float64 `json:"balance_alert_threshold" redis:"bat"`
  260. }
  261. func (g *GroupCache) GetAvailableSets() []string {
  262. if len(g.AvailableSets) == 0 {
  263. return []string{ChannelDefaultSet}
  264. }
  265. return g.AvailableSets
  266. }
  267. func (g *Group) ToGroupCache() *GroupCache {
  268. modelConfigs := make(redisGroupModelConfigMap, len(g.GroupModelConfigs))
  269. for _, modelConfig := range g.GroupModelConfigs {
  270. modelConfigs[modelConfig.Model] = modelConfig
  271. }
  272. return &GroupCache{
  273. ID: g.ID,
  274. Status: g.Status,
  275. UsedAmount: g.UsedAmount,
  276. RPMRatio: g.RPMRatio,
  277. TPMRatio: g.TPMRatio,
  278. AvailableSets: g.AvailableSets,
  279. ModelConfigs: modelConfigs,
  280. BalanceAlertEnabled: g.BalanceAlertEnabled,
  281. BalanceAlertThreshold: g.BalanceAlertThreshold,
  282. }
  283. }
  284. func CacheDeleteGroup(id string) error {
  285. if !common.RedisEnabled {
  286. return nil
  287. }
  288. return common.RDB.Del(context.Background(), common.RedisKeyf(GroupCacheKey, id)).Err()
  289. }
  290. var updateGroupRPMRatioScript = redis.NewScript(`
  291. if redis.call("HExists", KEYS[1], "rpm_r") then
  292. redis.call("HSet", KEYS[1], "rpm_r", ARGV[1])
  293. end
  294. return redis.status_reply("ok")
  295. `)
  296. func CacheUpdateGroupRPMRatio(id string, rpmRatio float64) error {
  297. if !common.RedisEnabled {
  298. return nil
  299. }
  300. return updateGroupRPMRatioScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, rpmRatio).
  301. Err()
  302. }
  303. var updateGroupTPMRatioScript = redis.NewScript(`
  304. if redis.call("HExists", KEYS[1], "tpm_r") then
  305. redis.call("HSet", KEYS[1], "tpm_r", ARGV[1])
  306. end
  307. return redis.status_reply("ok")
  308. `)
  309. func CacheUpdateGroupTPMRatio(id string, tpmRatio float64) error {
  310. if !common.RedisEnabled {
  311. return nil
  312. }
  313. return updateGroupTPMRatioScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, tpmRatio).
  314. Err()
  315. }
  316. var updateGroupStatusScript = redis.NewScript(`
  317. if redis.call("HExists", KEYS[1], "st") then
  318. redis.call("HSet", KEYS[1], "st", ARGV[1])
  319. end
  320. return redis.status_reply("ok")
  321. `)
  322. func CacheUpdateGroupStatus(id string, status int) error {
  323. if !common.RedisEnabled {
  324. return nil
  325. }
  326. return updateGroupStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, status).
  327. Err()
  328. }
  329. func CacheSetGroup(group *GroupCache) error {
  330. if !common.RedisEnabled {
  331. return nil
  332. }
  333. key := common.RedisKeyf(GroupCacheKey, group.ID)
  334. pipe := common.RDB.Pipeline()
  335. pipe.HSet(context.Background(), key, group)
  336. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  337. pipe.Expire(context.Background(), key, expireTime)
  338. _, err := pipe.Exec(context.Background())
  339. return err
  340. }
  341. func CacheGetGroup(id string) (*GroupCache, error) {
  342. if !common.RedisEnabled {
  343. group, err := GetGroupByID(id, true)
  344. if err != nil {
  345. return nil, err
  346. }
  347. return group.ToGroupCache(), nil
  348. }
  349. cacheKey := common.RedisKeyf(GroupCacheKey, id)
  350. groupCache := &GroupCache{}
  351. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(groupCache)
  352. if err == nil && groupCache.Status != 0 {
  353. groupCache.ID = id
  354. return groupCache, nil
  355. } else if err != nil && !errors.Is(err, redis.Nil) {
  356. log.Errorf("get group (%s) from redis error: %s", id, err.Error())
  357. }
  358. group, err := GetGroupByID(id, true)
  359. if err != nil {
  360. return nil, err
  361. }
  362. gc := group.ToGroupCache()
  363. if err := CacheSetGroup(gc); err != nil {
  364. log.Error("redis set group error: " + err.Error())
  365. }
  366. return gc, nil
  367. }
  368. var updateGroupUsedAmountOnlyIncreaseScript = redis.NewScript(`
  369. local used_amount = redis.call("HGet", KEYS[1], "ua")
  370. if used_amount == false then
  371. return redis.status_reply("ok")
  372. end
  373. if ARGV[1] < used_amount then
  374. return redis.status_reply("ok")
  375. end
  376. redis.call("HSet", KEYS[1], "ua", ARGV[1])
  377. return redis.status_reply("ok")
  378. `)
  379. func CacheUpdateGroupUsedAmountOnlyIncrease(id string, amount float64) error {
  380. if !common.RedisEnabled {
  381. return nil
  382. }
  383. return updateGroupUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, amount).
  384. Err()
  385. }
  386. type GroupMCPCache struct {
  387. ID string `json:"id" redis:"i"`
  388. GroupID string `json:"group_id" redis:"g"`
  389. Status GroupMCPStatus `json:"status" redis:"s"`
  390. Type GroupMCPType `json:"type" redis:"t"`
  391. ProxyConfig *GroupMCPProxyConfig `json:"proxy_config" redis:"pc"`
  392. OpenAPIConfig *MCPOpenAPIConfig `json:"openapi_config" redis:"oc"`
  393. }
  394. func (g *GroupMCP) ToGroupMCPCache() *GroupMCPCache {
  395. return &GroupMCPCache{
  396. ID: g.ID,
  397. GroupID: g.GroupID,
  398. Status: g.Status,
  399. Type: g.Type,
  400. ProxyConfig: g.ProxyConfig,
  401. OpenAPIConfig: g.OpenAPIConfig,
  402. }
  403. }
  404. const (
  405. GroupMCPCacheKey = "group_mcp:%s:%s" // group_id:mcp_id
  406. )
  407. func CacheDeleteGroupMCP(groupID, mcpID string) error {
  408. if !common.RedisEnabled {
  409. return nil
  410. }
  411. return common.RDB.Del(context.Background(), common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)).
  412. Err()
  413. }
  414. func CacheSetGroupMCP(groupMCP *GroupMCPCache) error {
  415. if !common.RedisEnabled {
  416. return nil
  417. }
  418. key := common.RedisKeyf(GroupMCPCacheKey, groupMCP.GroupID, groupMCP.ID)
  419. pipe := common.RDB.Pipeline()
  420. pipe.HSet(context.Background(), key, groupMCP)
  421. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  422. pipe.Expire(context.Background(), key, expireTime)
  423. _, err := pipe.Exec(context.Background())
  424. return err
  425. }
  426. func CacheGetGroupMCP(groupID, mcpID string) (*GroupMCPCache, error) {
  427. if !common.RedisEnabled {
  428. groupMCP, err := GetGroupMCPByID(mcpID, groupID)
  429. if err != nil {
  430. return nil, err
  431. }
  432. return groupMCP.ToGroupMCPCache(), nil
  433. }
  434. cacheKey := common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)
  435. groupMCPCache := &GroupMCPCache{}
  436. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(groupMCPCache)
  437. if err == nil && groupMCPCache.ID != "" {
  438. return groupMCPCache, nil
  439. } else if err != nil && !errors.Is(err, redis.Nil) {
  440. log.Errorf("get group mcp (%s:%s) from redis error: %s", groupID, mcpID, err.Error())
  441. }
  442. groupMCP, err := GetGroupMCPByID(mcpID, groupID)
  443. if err != nil {
  444. return nil, err
  445. }
  446. gmc := groupMCP.ToGroupMCPCache()
  447. if err := CacheSetGroupMCP(gmc); err != nil {
  448. log.Error("redis set group mcp error: " + err.Error())
  449. }
  450. return gmc, nil
  451. }
  452. var updateGroupMCPStatusScript = redis.NewScript(`
  453. if redis.call("HExists", KEYS[1], "s") then
  454. redis.call("HSet", KEYS[1], "s", ARGV[1])
  455. end
  456. return redis.status_reply("ok")
  457. `)
  458. func CacheUpdateGroupMCPStatus(groupID, mcpID string, status GroupMCPStatus) error {
  459. if !common.RedisEnabled {
  460. return nil
  461. }
  462. return updateGroupMCPStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)}, status).
  463. Err()
  464. }
  465. type PublicMCPCache struct {
  466. ID string `json:"id" redis:"i"`
  467. Status PublicMCPStatus `json:"status" redis:"s"`
  468. Type PublicMCPType `json:"type" redis:"t"`
  469. Price MCPPrice `json:"price" redis:"p"`
  470. ProxyConfig *PublicMCPProxyConfig `json:"proxy_config" redis:"pc"`
  471. OpenAPIConfig *MCPOpenAPIConfig `json:"openapi_config" redis:"oc"`
  472. EmbedConfig *MCPEmbeddingConfig `json:"embed_config" redis:"ec"`
  473. }
  474. func (p *PublicMCP) ToPublicMCPCache() *PublicMCPCache {
  475. return &PublicMCPCache{
  476. ID: p.ID,
  477. Status: p.Status,
  478. Type: p.Type,
  479. Price: p.Price,
  480. ProxyConfig: p.ProxyConfig,
  481. OpenAPIConfig: p.OpenAPIConfig,
  482. EmbedConfig: p.EmbedConfig,
  483. }
  484. }
  485. const (
  486. PublicMCPCacheKey = "public_mcp:%s" // mcp_id
  487. )
  488. func CacheDeletePublicMCP(mcpID string) error {
  489. if !common.RedisEnabled {
  490. return nil
  491. }
  492. return common.RDB.Del(context.Background(), common.RedisKeyf(PublicMCPCacheKey, mcpID)).Err()
  493. }
  494. func CacheSetPublicMCP(publicMCP *PublicMCPCache) error {
  495. if !common.RedisEnabled {
  496. return nil
  497. }
  498. key := common.RedisKeyf(PublicMCPCacheKey, publicMCP.ID)
  499. pipe := common.RDB.Pipeline()
  500. pipe.HSet(context.Background(), key, publicMCP)
  501. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  502. pipe.Expire(context.Background(), key, expireTime)
  503. _, err := pipe.Exec(context.Background())
  504. return err
  505. }
  506. func CacheGetPublicMCP(mcpID string) (*PublicMCPCache, error) {
  507. if !common.RedisEnabled {
  508. publicMCP, err := GetPublicMCPByID(mcpID)
  509. if err != nil {
  510. return nil, err
  511. }
  512. return publicMCP.ToPublicMCPCache(), nil
  513. }
  514. cacheKey := common.RedisKeyf(PublicMCPCacheKey, mcpID)
  515. publicMCPCache := &PublicMCPCache{}
  516. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(publicMCPCache)
  517. if err == nil && publicMCPCache.ID != "" {
  518. return publicMCPCache, nil
  519. } else if err != nil && !errors.Is(err, redis.Nil) {
  520. log.Errorf("get public mcp (%s) from redis error: %s", mcpID, err.Error())
  521. }
  522. publicMCP, err := GetPublicMCPByID(mcpID)
  523. if err != nil {
  524. return nil, err
  525. }
  526. pmc := publicMCP.ToPublicMCPCache()
  527. if err := CacheSetPublicMCP(pmc); err != nil {
  528. log.Error("redis set public mcp error: " + err.Error())
  529. }
  530. return pmc, nil
  531. }
  532. var updatePublicMCPStatusScript = redis.NewScript(`
  533. if redis.call("HExists", KEYS[1], "s") then
  534. redis.call("HSet", KEYS[1], "s", ARGV[1])
  535. end
  536. return redis.status_reply("ok")
  537. `)
  538. func CacheUpdatePublicMCPStatus(mcpID string, status PublicMCPStatus) error {
  539. if !common.RedisEnabled {
  540. return nil
  541. }
  542. return updatePublicMCPStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(PublicMCPCacheKey, mcpID)}, status).
  543. Err()
  544. }
  545. const (
  546. PublicMCPReusingParamCacheKey = "public_mcp_param:%s:%s" // mcp_id:group_id
  547. )
  548. type PublicMCPReusingParamCache struct {
  549. MCPID string `json:"mcp_id" redis:"m"`
  550. GroupID string `json:"group_id" redis:"g"`
  551. Params redisMap[string, string] `json:"params" redis:"p"`
  552. }
  553. func (p *PublicMCPReusingParam) ToPublicMCPReusingParamCache() PublicMCPReusingParamCache {
  554. return PublicMCPReusingParamCache{
  555. MCPID: p.MCPID,
  556. GroupID: p.GroupID,
  557. Params: p.Params,
  558. }
  559. }
  560. func CacheDeletePublicMCPReusingParam(mcpID, groupID string) error {
  561. if !common.RedisEnabled {
  562. return nil
  563. }
  564. return common.RDB.Del(context.Background(), common.RedisKeyf(PublicMCPReusingParamCacheKey, mcpID, groupID)).
  565. Err()
  566. }
  567. func CacheSetPublicMCPReusingParam(param PublicMCPReusingParamCache) error {
  568. if !common.RedisEnabled {
  569. return nil
  570. }
  571. key := common.RedisKeyf(PublicMCPReusingParamCacheKey, param.MCPID, param.GroupID)
  572. pipe := common.RDB.Pipeline()
  573. pipe.HSet(context.Background(), key, param)
  574. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  575. pipe.Expire(context.Background(), key, expireTime)
  576. _, err := pipe.Exec(context.Background())
  577. return err
  578. }
  579. func CacheGetPublicMCPReusingParam(mcpID, groupID string) (PublicMCPReusingParamCache, error) {
  580. if groupID == "" {
  581. return PublicMCPReusingParamCache{
  582. MCPID: mcpID,
  583. GroupID: groupID,
  584. Params: make(map[string]string),
  585. }, nil
  586. }
  587. if !common.RedisEnabled {
  588. param, err := GetPublicMCPReusingParam(mcpID, groupID)
  589. if err != nil {
  590. return PublicMCPReusingParamCache{}, err
  591. }
  592. return param.ToPublicMCPReusingParamCache(), nil
  593. }
  594. cacheKey := common.RedisKeyf(PublicMCPReusingParamCacheKey, mcpID, groupID)
  595. paramCache := PublicMCPReusingParamCache{}
  596. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(&paramCache)
  597. if err == nil && paramCache.MCPID != "" {
  598. return paramCache, nil
  599. } else if err != nil && !errors.Is(err, redis.Nil) {
  600. log.Errorf("get public mcp reusing param (%s:%s) from redis error: %s", mcpID, groupID, err.Error())
  601. }
  602. param, err := GetPublicMCPReusingParam(mcpID, groupID)
  603. if err != nil {
  604. return PublicMCPReusingParamCache{}, err
  605. }
  606. prc := param.ToPublicMCPReusingParamCache()
  607. if err := CacheSetPublicMCPReusingParam(prc); err != nil {
  608. log.Error("redis set public mcp reusing param error: " + err.Error())
  609. }
  610. return prc, nil
  611. }
  612. const (
  613. StoreCacheKey = "storev2:%s:%d:%s" // store_id
  614. )
  615. type StoreCache struct {
  616. ID string `json:"id" redis:"i"`
  617. GroupID string `json:"group_id" redis:"g"`
  618. TokenID int `json:"token_id" redis:"t"`
  619. ChannelID int `json:"channel_id" redis:"c"`
  620. Model string `json:"model" redis:"m"`
  621. ExpiresAt time.Time `json:"expires_at" redis:"e"`
  622. }
  623. func (s *StoreV2) ToStoreCache() *StoreCache {
  624. return &StoreCache{
  625. ID: s.ID,
  626. GroupID: s.GroupID,
  627. TokenID: s.TokenID,
  628. ChannelID: s.ChannelID,
  629. Model: s.Model,
  630. ExpiresAt: s.ExpiresAt,
  631. }
  632. }
  633. func CacheSetStore(store *StoreCache) error {
  634. if !common.RedisEnabled {
  635. return nil
  636. }
  637. key := common.RedisKeyf(StoreCacheKey, store.GroupID, store.TokenID, store.ID)
  638. pipe := common.RDB.Pipeline()
  639. pipe.HSet(context.Background(), key, store)
  640. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  641. pipe.Expire(context.Background(), key, expireTime)
  642. _, err := pipe.Exec(context.Background())
  643. return err
  644. }
  645. func CacheGetStore(group string, tokenID int, id string) (*StoreCache, error) {
  646. if !common.RedisEnabled {
  647. store, err := GetStore(group, tokenID, id)
  648. if err != nil {
  649. return nil, err
  650. }
  651. return store.ToStoreCache(), nil
  652. }
  653. cacheKey := common.RedisKeyf(StoreCacheKey, group, tokenID, id)
  654. storeCache := &StoreCache{}
  655. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(storeCache)
  656. if err == nil && storeCache.ID != "" {
  657. return storeCache, nil
  658. }
  659. store, err := GetStore(group, tokenID, id)
  660. if err != nil {
  661. return nil, err
  662. }
  663. sc := store.ToStoreCache()
  664. if err := CacheSetStore(sc); err != nil {
  665. log.Error("redis set store error: " + err.Error())
  666. }
  667. return sc, nil
  668. }
  669. type ModelConfigCache interface {
  670. GetModelConfig(model string) (ModelConfig, bool)
  671. }
  672. // read-only cache
  673. //
  674. type ModelCaches struct {
  675. ModelConfig ModelConfigCache
  676. // map[set][]model
  677. EnabledModelsBySet map[string][]string
  678. // map[set][]modelconfig
  679. EnabledModelConfigsBySet map[string][]ModelConfig
  680. // map[model]modelconfig
  681. EnabledModelConfigsMap map[string]ModelConfig
  682. // map[set]map[model][]channel
  683. EnabledModel2ChannelsBySet map[string]map[string][]*Channel
  684. // map[set]map[model][]channel
  685. DisabledModel2ChannelsBySet map[string]map[string][]*Channel
  686. }
  687. var modelCaches atomic.Pointer[ModelCaches]
  688. func init() {
  689. modelCaches.Store(new(ModelCaches))
  690. }
  691. func LoadModelCaches() *ModelCaches {
  692. return modelCaches.Load()
  693. }
  694. // InitModelConfigAndChannelCache initializes the channel cache from database
  695. func InitModelConfigAndChannelCache() error {
  696. modelConfig, err := initializeModelConfigCache()
  697. if err != nil {
  698. return err
  699. }
  700. // Load enabled channels from database
  701. enabledChannels, err := LoadEnabledChannels()
  702. if err != nil {
  703. return err
  704. }
  705. // Build model to channels map by set
  706. enabledModel2ChannelsBySet := buildModelToChannelsBySetMap(enabledChannels)
  707. // Sort channels by priority within each set
  708. sortChannelsByPriorityBySet(enabledModel2ChannelsBySet)
  709. // Build enabled models and configs by set
  710. enabledModelsBySet, enabledModelConfigsBySet, enabledModelConfigsMap := buildEnabledModelsBySet(
  711. enabledModel2ChannelsBySet,
  712. modelConfig,
  713. )
  714. // Load disabled channels
  715. disabledChannels, err := LoadDisabledChannels()
  716. if err != nil {
  717. return err
  718. }
  719. // Build disabled model to channels map by set
  720. disabledModel2ChannelsBySet := buildModelToChannelsBySetMap(disabledChannels)
  721. // Update global cache atomically
  722. modelCaches.Store(&ModelCaches{
  723. ModelConfig: modelConfig,
  724. EnabledModelsBySet: enabledModelsBySet,
  725. EnabledModelConfigsBySet: enabledModelConfigsBySet,
  726. EnabledModelConfigsMap: enabledModelConfigsMap,
  727. EnabledModel2ChannelsBySet: enabledModel2ChannelsBySet,
  728. DisabledModel2ChannelsBySet: disabledModel2ChannelsBySet,
  729. })
  730. return nil
  731. }
  732. func LoadEnabledChannels() ([]*Channel, error) {
  733. var channels []*Channel
  734. err := DB.Where("status = ?", ChannelStatusEnabled).Find(&channels).Error
  735. if err != nil {
  736. return nil, err
  737. }
  738. for _, channel := range channels {
  739. initializeChannelModels(channel)
  740. initializeChannelModelMapping(channel)
  741. }
  742. return channels, nil
  743. }
  744. func LoadDisabledChannels() ([]*Channel, error) {
  745. var channels []*Channel
  746. err := DB.Where("status = ?", ChannelStatusDisabled).Find(&channels).Error
  747. if err != nil {
  748. return nil, err
  749. }
  750. for _, channel := range channels {
  751. initializeChannelModels(channel)
  752. initializeChannelModelMapping(channel)
  753. }
  754. return channels, nil
  755. }
  756. func LoadChannels() ([]*Channel, error) {
  757. var channels []*Channel
  758. err := DB.Find(&channels).Error
  759. if err != nil {
  760. return nil, err
  761. }
  762. for _, channel := range channels {
  763. initializeChannelModels(channel)
  764. initializeChannelModelMapping(channel)
  765. }
  766. return channels, nil
  767. }
  768. func LoadChannelByID(id int) (*Channel, error) {
  769. var channel Channel
  770. err := DB.First(&channel, id).Error
  771. if err != nil {
  772. return nil, err
  773. }
  774. initializeChannelModels(&channel)
  775. initializeChannelModelMapping(&channel)
  776. return &channel, nil
  777. }
  778. var _ ModelConfigCache = (*modelConfigMapCache)(nil)
  779. type modelConfigMapCache struct {
  780. modelConfigMap map[string]ModelConfig
  781. }
  782. func (m *modelConfigMapCache) GetModelConfig(model string) (ModelConfig, bool) {
  783. config, ok := m.modelConfigMap[model]
  784. return config, ok
  785. }
  786. var _ ModelConfigCache = (*disabledModelConfigCache)(nil)
  787. type disabledModelConfigCache struct {
  788. modelConfigs ModelConfigCache
  789. }
  790. func (d *disabledModelConfigCache) GetModelConfig(model string) (ModelConfig, bool) {
  791. if config, ok := d.modelConfigs.GetModelConfig(model); ok {
  792. return config, true
  793. }
  794. return NewDefaultModelConfig(model), true
  795. }
  796. func initializeModelConfigCache() (ModelConfigCache, error) {
  797. modelConfigs, err := GetAllModelConfigs()
  798. if err != nil {
  799. return nil, err
  800. }
  801. newModelConfigMap := make(map[string]ModelConfig)
  802. for _, modelConfig := range modelConfigs {
  803. newModelConfigMap[modelConfig.Model] = modelConfig
  804. }
  805. configs := &modelConfigMapCache{modelConfigMap: newModelConfigMap}
  806. if config.DisableModelConfig {
  807. return &disabledModelConfigCache{modelConfigs: configs}, nil
  808. }
  809. return configs, nil
  810. }
  811. func initializeChannelModels(channel *Channel) {
  812. if len(channel.Models) == 0 {
  813. channel.Models = config.GetDefaultChannelModels()[int(channel.Type)]
  814. return
  815. }
  816. findedModels, missingModels, err := GetModelConfigWithModels(channel.Models)
  817. if err != nil {
  818. return
  819. }
  820. if len(missingModels) > 0 {
  821. slices.Sort(missingModels)
  822. log.Errorf("model config not found: %v", missingModels)
  823. }
  824. slices.Sort(findedModels)
  825. channel.Models = findedModels
  826. }
  827. func initializeChannelModelMapping(channel *Channel) {
  828. if len(channel.ModelMapping) == 0 {
  829. channel.ModelMapping = config.GetDefaultChannelModelMapping()[int(channel.Type)]
  830. }
  831. }
  832. func buildModelToChannelsBySetMap(channels []*Channel) map[string]map[string][]*Channel {
  833. modelMapBySet := make(map[string]map[string][]*Channel)
  834. for _, channel := range channels {
  835. sets := channel.GetSets()
  836. for _, set := range sets {
  837. if _, ok := modelMapBySet[set]; !ok {
  838. modelMapBySet[set] = make(map[string][]*Channel)
  839. }
  840. for _, model := range channel.Models {
  841. modelMapBySet[set][model] = append(modelMapBySet[set][model], channel)
  842. }
  843. }
  844. }
  845. return modelMapBySet
  846. }
  847. func sortChannelsByPriorityBySet(modelMapBySet map[string]map[string][]*Channel) {
  848. for _, modelMap := range modelMapBySet {
  849. for _, channels := range modelMap {
  850. sort.Slice(channels, func(i, j int) bool {
  851. return channels[i].GetPriority() > channels[j].GetPriority()
  852. })
  853. }
  854. }
  855. }
  856. func buildEnabledModelsBySet(
  857. modelMapBySet map[string]map[string][]*Channel,
  858. modelConfigCache ModelConfigCache,
  859. ) (
  860. map[string][]string,
  861. map[string][]ModelConfig,
  862. map[string]ModelConfig,
  863. ) {
  864. modelsBySet := make(map[string][]string)
  865. modelConfigsBySet := make(map[string][]ModelConfig)
  866. modelConfigsMap := make(map[string]ModelConfig)
  867. for set, modelMap := range modelMapBySet {
  868. models := make([]string, 0)
  869. configs := make([]ModelConfig, 0)
  870. appended := make(map[string]struct{})
  871. for model := range modelMap {
  872. if _, ok := appended[model]; ok {
  873. continue
  874. }
  875. if config, ok := modelConfigCache.GetModelConfig(model); ok {
  876. models = append(models, model)
  877. configs = append(configs, config)
  878. appended[model] = struct{}{}
  879. modelConfigsMap[model] = config
  880. }
  881. }
  882. slices.Sort(models)
  883. slices.SortStableFunc(configs, SortModelConfigsFunc)
  884. modelsBySet[set] = models
  885. modelConfigsBySet[set] = configs
  886. }
  887. return modelsBySet, modelConfigsBySet, modelConfigsMap
  888. }
  889. func SortModelConfigsFunc(i, j ModelConfig) int {
  890. if i.Owner != j.Owner {
  891. if natural.Less(string(i.Owner), string(j.Owner)) {
  892. return -1
  893. }
  894. return 1
  895. }
  896. if i.Type != j.Type {
  897. if i.Type < j.Type {
  898. return -1
  899. }
  900. return 1
  901. }
  902. if i.Model == j.Model {
  903. return 0
  904. }
  905. if natural.Less(i.Model, j.Model) {
  906. return -1
  907. }
  908. return 1
  909. }
  910. func SyncModelConfigAndChannelCache(
  911. ctx context.Context,
  912. wg *sync.WaitGroup,
  913. frequency time.Duration,
  914. ) {
  915. defer wg.Done()
  916. ticker := time.NewTicker(frequency)
  917. defer ticker.Stop()
  918. for {
  919. select {
  920. case <-ctx.Done():
  921. return
  922. case <-ticker.C:
  923. err := InitModelConfigAndChannelCache()
  924. if err != nil {
  925. notify.ErrorThrottle(
  926. "syncModelChannel",
  927. time.Minute,
  928. "failed to sync channels",
  929. err.Error(),
  930. )
  931. }
  932. }
  933. }
  934. }