cache.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185
  1. package model
  2. import (
  3. "context"
  4. "encoding"
  5. "errors"
  6. "math/rand/v2"
  7. "slices"
  8. "sort"
  9. "strings"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "github.com/bytedance/sonic"
  14. "github.com/labring/aiproxy/core/common"
  15. "github.com/labring/aiproxy/core/common/config"
  16. "github.com/labring/aiproxy/core/common/conv"
  17. "github.com/labring/aiproxy/core/common/notify"
  18. "github.com/maruel/natural"
  19. "github.com/redis/go-redis/v9"
  20. log "github.com/sirupsen/logrus"
  21. )
  // Cache timing and Redis key templates. The key templates are expanded with
  // common.RedisKeyf.
  const (
  	// SyncFrequency is the base TTL for cache entries written by the
  	// CacheSet* helpers, which add a random jitter of a few tens of seconds.
  	SyncFrequency = 3 * time.Minute

  	// TokenCacheKey is the hash key of a cached token, formatted with the
  	// token key.
  	TokenCacheKey = "token:%s"

  	// GroupCacheKey is the hash key of a cached group, formatted with the
  	// group ID.
  	GroupCacheKey = "group:%s"

  	// GroupModelTPMKey is presumably formatted with the group ID (its users
  	// are outside this view — TODO confirm).
  	GroupModelTPMKey = "group:%s:model_tpm"
  )
  28. var (
  29. _ encoding.BinaryMarshaler = (*redisStringSlice)(nil)
  30. _ redis.Scanner = (*redisStringSlice)(nil)
  31. )
  32. type redisStringSlice []string
  33. func (r *redisStringSlice) ScanRedis(value string) error {
  34. return sonic.Unmarshal(conv.StringToBytes(value), r)
  35. }
  36. func (r redisStringSlice) MarshalBinary() ([]byte, error) {
  37. return sonic.Marshal(r)
  38. }
  39. type redisTime time.Time
  40. var (
  41. _ redis.Scanner = (*redisTime)(nil)
  42. _ encoding.BinaryMarshaler = (*redisTime)(nil)
  43. )
  44. func (t *redisTime) ScanRedis(value string) error {
  45. return (*time.Time)(t).UnmarshalBinary(conv.StringToBytes(value))
  46. }
  47. func (t redisTime) MarshalBinary() ([]byte, error) {
  48. return time.Time(t).MarshalBinary()
  49. }
  50. type TokenCache struct {
  51. Group string `json:"group" redis:"g"`
  52. Key string `json:"-" redis:"-"`
  53. Name string `json:"name" redis:"n"`
  54. Subnets redisStringSlice `json:"subnets" redis:"s"`
  55. Models redisStringSlice `json:"models" redis:"m"`
  56. ID int `json:"id" redis:"i"`
  57. Status int `json:"status" redis:"st"`
  58. UsedAmount float64 `json:"used_amount" redis:"u"`
  59. // Quota system
  60. Quota float64 `json:"quota" redis:"q"`
  61. PeriodQuota float64 `json:"period_quota" redis:"pq"`
  62. PeriodType string `json:"period_type" redis:"pt"`
  63. PeriodLastUpdateTime redisTime `json:"period_last_update_time" redis:"plut"`
  64. PeriodLastUpdateAmount float64 `json:"period_last_update_amount" redis:"plua"`
  65. availableSets []string
  66. modelsBySet map[string][]string
  67. }
  68. func (t *TokenCache) SetAvailableSets(availableSets []string) {
  69. t.availableSets = availableSets
  70. }
  71. func (t *TokenCache) SetModelsBySet(modelsBySet map[string][]string) {
  72. t.modelsBySet = modelsBySet
  73. }
  74. func (t *TokenCache) FindModel(model string) string {
  75. var findModel string
  76. if len(t.Models) != 0 {
  77. if !slices.ContainsFunc(t.Models, func(e string) bool {
  78. ok := strings.EqualFold(e, model)
  79. if ok {
  80. findModel = e
  81. }
  82. return ok
  83. }) {
  84. return findModel
  85. }
  86. }
  87. return containsModel(model, t.availableSets, t.modelsBySet)
  88. }
  // containsModel searches sets in order for a model whose name equals model
  // ignoring case. It returns that entry's own spelling from modelsBySet, or ""
  // when no listed set contains it. Sets missing from modelsBySet are treated
  // as empty.
  func containsModel(model string, sets []string, modelsBySet map[string][]string) string {
  	for _, set := range sets {
  		candidates := modelsBySet[set]
  		idx := slices.IndexFunc(candidates, func(candidate string) bool {
  			return strings.EqualFold(candidate, model)
  		})
  		if idx >= 0 {
  			return candidates[idx]
  		}
  	}
  	return ""
  }
  104. func (t *TokenCache) Range(fn func(model string) bool) {
  105. ranged := make(map[string]struct{})
  106. if len(t.Models) != 0 {
  107. for _, model := range t.Models {
  108. if _, ok := ranged[model]; ok {
  109. continue
  110. }
  111. model = containsModel(model, t.availableSets, t.modelsBySet)
  112. if model == "" {
  113. continue
  114. }
  115. ranged[model] = struct{}{}
  116. if !fn(model) {
  117. return
  118. }
  119. }
  120. return
  121. }
  122. for _, set := range t.availableSets {
  123. for _, model := range t.modelsBySet[set] {
  124. if _, ok := ranged[model]; !ok {
  125. if !fn(model) {
  126. return
  127. }
  128. }
  129. ranged[model] = struct{}{}
  130. }
  131. }
  132. }
  133. func (t *Token) ToTokenCache() *TokenCache {
  134. return &TokenCache{
  135. ID: t.ID,
  136. Group: t.GroupID,
  137. Key: t.Key,
  138. Name: string(t.Name),
  139. Models: t.Models,
  140. Subnets: t.Subnets,
  141. Status: t.Status,
  142. UsedAmount: t.UsedAmount,
  143. Quota: t.Quota,
  144. PeriodQuota: t.PeriodQuota,
  145. PeriodType: string(t.PeriodType),
  146. PeriodLastUpdateTime: redisTime(t.PeriodLastUpdateTime),
  147. PeriodLastUpdateAmount: t.PeriodLastUpdateAmount,
  148. }
  149. }
  150. func CacheDeleteToken(key string) error {
  151. if !common.RedisEnabled {
  152. return nil
  153. }
  154. return common.RDB.Del(context.Background(), common.RedisKeyf(TokenCacheKey, key)).Err()
  155. }
  156. func CacheSetToken(token *TokenCache) error {
  157. if !common.RedisEnabled {
  158. return nil
  159. }
  160. key := common.RedisKeyf(TokenCacheKey, token.Key)
  161. pipe := common.RDB.Pipeline()
  162. pipe.HSet(context.Background(), key, token)
  163. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  164. pipe.Expire(context.Background(), key, expireTime)
  165. _, err := pipe.Exec(context.Background())
  166. return err
  167. }
  168. func CacheGetTokenByKey(key string) (*TokenCache, error) {
  169. if !common.RedisEnabled {
  170. token, err := GetTokenByKey(key)
  171. if err != nil {
  172. return nil, err
  173. }
  174. return token.ToTokenCache(), nil
  175. }
  176. cacheKey := common.RedisKeyf(TokenCacheKey, key)
  177. tokenCache := &TokenCache{}
  178. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(tokenCache)
  179. if err == nil && tokenCache.ID != 0 {
  180. tokenCache.Key = key
  181. return tokenCache, nil
  182. } else if err != nil && !errors.Is(err, redis.Nil) {
  183. log.Errorf("get token (%s) from redis error: %s", key, err.Error())
  184. }
  185. token, err := GetTokenByKey(key)
  186. if err != nil {
  187. return nil, err
  188. }
  189. tc := token.ToTokenCache()
  190. if err := CacheSetToken(tc); err != nil {
  191. log.Error("redis set token error: " + err.Error())
  192. }
  193. return tc, nil
  194. }
  195. var updateTokenUsedAmountOnlyIncreaseScript = redis.NewScript(`
  196. local used_amount = redis.call("HGet", KEYS[1], "ua")
  197. if used_amount == false then
  198. return redis.status_reply("ok")
  199. end
  200. if ARGV[1] < used_amount then
  201. return redis.status_reply("ok")
  202. end
  203. redis.call("HSet", KEYS[1], "ua", ARGV[1])
  204. return redis.status_reply("ok")
  205. `)
  206. func CacheUpdateTokenUsedAmountOnlyIncrease(key string, amount float64) error {
  207. if !common.RedisEnabled {
  208. return nil
  209. }
  210. return updateTokenUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, amount).
  211. Err()
  212. }
  213. // CacheResetTokenPeriodUsage resets period usage in cache and updates period last update info
  214. func CacheResetTokenPeriodUsage(
  215. key string,
  216. periodLastUpdateTime time.Time,
  217. periodLastUpdateAmount float64,
  218. ) error {
  219. if !common.RedisEnabled {
  220. return nil
  221. }
  222. cacheKey := common.RedisKeyf(TokenCacheKey, key)
  223. pipe := common.RDB.Pipeline()
  224. periodLastUpdateTimeBytes, _ := periodLastUpdateTime.MarshalBinary()
  225. pipe.HSet(context.Background(), cacheKey, "plut", periodLastUpdateTimeBytes)
  226. pipe.HSet(context.Background(), cacheKey, "plua", periodLastUpdateAmount)
  227. _, err := pipe.Exec(context.Background())
  228. return err
  229. }
  230. var updateTokenNameScript = redis.NewScript(`
  231. if redis.call("HExists", KEYS[1], "n") then
  232. redis.call("HSet", KEYS[1], "n", ARGV[1])
  233. end
  234. return redis.status_reply("ok")
  235. `)
  236. func CacheUpdateTokenName(key, name string) error {
  237. if !common.RedisEnabled {
  238. return nil
  239. }
  240. return updateTokenNameScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, name).
  241. Err()
  242. }
  243. var updateTokenStatusScript = redis.NewScript(`
  244. if redis.call("HExists", KEYS[1], "st") then
  245. redis.call("HSet", KEYS[1], "st", ARGV[1])
  246. end
  247. return redis.status_reply("ok")
  248. `)
  249. func CacheUpdateTokenStatus(key string, status int) error {
  250. if !common.RedisEnabled {
  251. return nil
  252. }
  253. return updateTokenStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, status).
  254. Err()
  255. }
  256. type redisMap[K comparable, V any] map[K]V
  257. var (
  258. _ redis.Scanner = (*redisMap[string, any])(nil)
  259. _ encoding.BinaryMarshaler = (*redisMap[string, any])(nil)
  260. )
  261. func (r *redisMap[K, V]) ScanRedis(value string) error {
  262. return sonic.UnmarshalString(value, r)
  263. }
  264. func (r redisMap[K, V]) MarshalBinary() ([]byte, error) {
  265. return sonic.Marshal(r)
  266. }
  267. type (
  268. redisGroupModelConfigMap = redisMap[string, GroupModelConfig]
  269. )
  270. type GroupCache struct {
  271. ID string `json:"-" redis:"-"`
  272. Status int `json:"status" redis:"st"`
  273. UsedAmount float64 `json:"used_amount" redis:"ua"`
  274. RPMRatio float64 `json:"rpm_ratio" redis:"rpm_r"`
  275. TPMRatio float64 `json:"tpm_ratio" redis:"tpm_r"`
  276. AvailableSets redisStringSlice `json:"available_sets" redis:"ass"`
  277. ModelConfigs redisGroupModelConfigMap `json:"model_configs" redis:"mc"`
  278. BalanceAlertEnabled bool `json:"balance_alert_enabled" redis:"bae"`
  279. BalanceAlertThreshold float64 `json:"balance_alert_threshold" redis:"bat"`
  280. }
  281. func (g *GroupCache) GetAvailableSets() []string {
  282. if len(g.AvailableSets) == 0 {
  283. return []string{ChannelDefaultSet}
  284. }
  285. return g.AvailableSets
  286. }
  287. func (g *Group) ToGroupCache() *GroupCache {
  288. modelConfigs := make(redisGroupModelConfigMap, len(g.GroupModelConfigs))
  289. for _, modelConfig := range g.GroupModelConfigs {
  290. modelConfigs[modelConfig.Model] = modelConfig
  291. }
  292. return &GroupCache{
  293. ID: g.ID,
  294. Status: g.Status,
  295. UsedAmount: g.UsedAmount,
  296. RPMRatio: g.RPMRatio,
  297. TPMRatio: g.TPMRatio,
  298. AvailableSets: g.AvailableSets,
  299. ModelConfigs: modelConfigs,
  300. BalanceAlertEnabled: g.BalanceAlertEnabled,
  301. BalanceAlertThreshold: g.BalanceAlertThreshold,
  302. }
  303. }
  304. func CacheDeleteGroup(id string) error {
  305. if !common.RedisEnabled {
  306. return nil
  307. }
  308. return common.RDB.Del(context.Background(), common.RedisKeyf(GroupCacheKey, id)).Err()
  309. }
  310. var updateGroupRPMRatioScript = redis.NewScript(`
  311. if redis.call("HExists", KEYS[1], "rpm_r") then
  312. redis.call("HSet", KEYS[1], "rpm_r", ARGV[1])
  313. end
  314. return redis.status_reply("ok")
  315. `)
  316. func CacheUpdateGroupRPMRatio(id string, rpmRatio float64) error {
  317. if !common.RedisEnabled {
  318. return nil
  319. }
  320. return updateGroupRPMRatioScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, rpmRatio).
  321. Err()
  322. }
  323. var updateGroupTPMRatioScript = redis.NewScript(`
  324. if redis.call("HExists", KEYS[1], "tpm_r") then
  325. redis.call("HSet", KEYS[1], "tpm_r", ARGV[1])
  326. end
  327. return redis.status_reply("ok")
  328. `)
  329. func CacheUpdateGroupTPMRatio(id string, tpmRatio float64) error {
  330. if !common.RedisEnabled {
  331. return nil
  332. }
  333. return updateGroupTPMRatioScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, tpmRatio).
  334. Err()
  335. }
  336. var updateGroupStatusScript = redis.NewScript(`
  337. if redis.call("HExists", KEYS[1], "st") then
  338. redis.call("HSet", KEYS[1], "st", ARGV[1])
  339. end
  340. return redis.status_reply("ok")
  341. `)
  342. func CacheUpdateGroupStatus(id string, status int) error {
  343. if !common.RedisEnabled {
  344. return nil
  345. }
  346. return updateGroupStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, status).
  347. Err()
  348. }
  349. func CacheSetGroup(group *GroupCache) error {
  350. if !common.RedisEnabled {
  351. return nil
  352. }
  353. key := common.RedisKeyf(GroupCacheKey, group.ID)
  354. pipe := common.RDB.Pipeline()
  355. pipe.HSet(context.Background(), key, group)
  356. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  357. pipe.Expire(context.Background(), key, expireTime)
  358. _, err := pipe.Exec(context.Background())
  359. return err
  360. }
  361. func CacheGetGroup(id string) (*GroupCache, error) {
  362. if !common.RedisEnabled {
  363. group, err := GetGroupByID(id, true)
  364. if err != nil {
  365. return nil, err
  366. }
  367. return group.ToGroupCache(), nil
  368. }
  369. cacheKey := common.RedisKeyf(GroupCacheKey, id)
  370. groupCache := &GroupCache{}
  371. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(groupCache)
  372. if err == nil && groupCache.Status != 0 {
  373. groupCache.ID = id
  374. return groupCache, nil
  375. } else if err != nil && !errors.Is(err, redis.Nil) {
  376. log.Errorf("get group (%s) from redis error: %s", id, err.Error())
  377. }
  378. group, err := GetGroupByID(id, true)
  379. if err != nil {
  380. return nil, err
  381. }
  382. gc := group.ToGroupCache()
  383. if err := CacheSetGroup(gc); err != nil {
  384. log.Error("redis set group error: " + err.Error())
  385. }
  386. return gc, nil
  387. }
  388. var updateGroupUsedAmountOnlyIncreaseScript = redis.NewScript(`
  389. local used_amount = redis.call("HGet", KEYS[1], "ua")
  390. if used_amount == false then
  391. return redis.status_reply("ok")
  392. end
  393. if ARGV[1] < used_amount then
  394. return redis.status_reply("ok")
  395. end
  396. redis.call("HSet", KEYS[1], "ua", ARGV[1])
  397. return redis.status_reply("ok")
  398. `)
  399. func CacheUpdateGroupUsedAmountOnlyIncrease(id string, amount float64) error {
  400. if !common.RedisEnabled {
  401. return nil
  402. }
  403. return updateGroupUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, amount).
  404. Err()
  405. }
  406. type GroupMCPCache struct {
  407. ID string `json:"id" redis:"i"`
  408. GroupID string `json:"group_id" redis:"g"`
  409. Status GroupMCPStatus `json:"status" redis:"s"`
  410. Type GroupMCPType `json:"type" redis:"t"`
  411. ProxyConfig *GroupMCPProxyConfig `json:"proxy_config" redis:"pc"`
  412. OpenAPIConfig *MCPOpenAPIConfig `json:"openapi_config" redis:"oc"`
  413. }
  414. func (g *GroupMCP) ToGroupMCPCache() *GroupMCPCache {
  415. return &GroupMCPCache{
  416. ID: g.ID,
  417. GroupID: g.GroupID,
  418. Status: g.Status,
  419. Type: g.Type,
  420. ProxyConfig: g.ProxyConfig,
  421. OpenAPIConfig: g.OpenAPIConfig,
  422. }
  423. }
  424. const (
  425. GroupMCPCacheKey = "group_mcp:%s:%s" // group_id:mcp_id
  426. )
  427. func CacheDeleteGroupMCP(groupID, mcpID string) error {
  428. if !common.RedisEnabled {
  429. return nil
  430. }
  431. return common.RDB.Del(context.Background(), common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)).
  432. Err()
  433. }
  434. func CacheSetGroupMCP(groupMCP *GroupMCPCache) error {
  435. if !common.RedisEnabled {
  436. return nil
  437. }
  438. key := common.RedisKeyf(GroupMCPCacheKey, groupMCP.GroupID, groupMCP.ID)
  439. pipe := common.RDB.Pipeline()
  440. pipe.HSet(context.Background(), key, groupMCP)
  441. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  442. pipe.Expire(context.Background(), key, expireTime)
  443. _, err := pipe.Exec(context.Background())
  444. return err
  445. }
  446. func CacheGetGroupMCP(groupID, mcpID string) (*GroupMCPCache, error) {
  447. if !common.RedisEnabled {
  448. groupMCP, err := GetGroupMCPByID(mcpID, groupID)
  449. if err != nil {
  450. return nil, err
  451. }
  452. return groupMCP.ToGroupMCPCache(), nil
  453. }
  454. cacheKey := common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)
  455. groupMCPCache := &GroupMCPCache{}
  456. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(groupMCPCache)
  457. if err == nil && groupMCPCache.ID != "" {
  458. return groupMCPCache, nil
  459. } else if err != nil && !errors.Is(err, redis.Nil) {
  460. log.Errorf("get group mcp (%s:%s) from redis error: %s", groupID, mcpID, err.Error())
  461. }
  462. groupMCP, err := GetGroupMCPByID(mcpID, groupID)
  463. if err != nil {
  464. return nil, err
  465. }
  466. gmc := groupMCP.ToGroupMCPCache()
  467. if err := CacheSetGroupMCP(gmc); err != nil {
  468. log.Error("redis set group mcp error: " + err.Error())
  469. }
  470. return gmc, nil
  471. }
  472. var updateGroupMCPStatusScript = redis.NewScript(`
  473. if redis.call("HExists", KEYS[1], "s") then
  474. redis.call("HSet", KEYS[1], "s", ARGV[1])
  475. end
  476. return redis.status_reply("ok")
  477. `)
  478. func CacheUpdateGroupMCPStatus(groupID, mcpID string, status GroupMCPStatus) error {
  479. if !common.RedisEnabled {
  480. return nil
  481. }
  482. return updateGroupMCPStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)}, status).
  483. Err()
  484. }
  485. type PublicMCPCache struct {
  486. ID string `json:"id" redis:"i"`
  487. Status PublicMCPStatus `json:"status" redis:"s"`
  488. Type PublicMCPType `json:"type" redis:"t"`
  489. Price MCPPrice `json:"price" redis:"p"`
  490. ProxyConfig *PublicMCPProxyConfig `json:"proxy_config" redis:"pc"`
  491. OpenAPIConfig *MCPOpenAPIConfig `json:"openapi_config" redis:"oc"`
  492. EmbedConfig *MCPEmbeddingConfig `json:"embed_config" redis:"ec"`
  493. }
  494. func (p *PublicMCP) ToPublicMCPCache() *PublicMCPCache {
  495. return &PublicMCPCache{
  496. ID: p.ID,
  497. Status: p.Status,
  498. Type: p.Type,
  499. Price: p.Price,
  500. ProxyConfig: p.ProxyConfig,
  501. OpenAPIConfig: p.OpenAPIConfig,
  502. EmbedConfig: p.EmbedConfig,
  503. }
  504. }
  505. const (
  506. PublicMCPCacheKey = "public_mcp:%s" // mcp_id
  507. )
  508. func CacheDeletePublicMCP(mcpID string) error {
  509. if !common.RedisEnabled {
  510. return nil
  511. }
  512. return common.RDB.Del(context.Background(), common.RedisKeyf(PublicMCPCacheKey, mcpID)).Err()
  513. }
  514. func CacheSetPublicMCP(publicMCP *PublicMCPCache) error {
  515. if !common.RedisEnabled {
  516. return nil
  517. }
  518. key := common.RedisKeyf(PublicMCPCacheKey, publicMCP.ID)
  519. pipe := common.RDB.Pipeline()
  520. pipe.HSet(context.Background(), key, publicMCP)
  521. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  522. pipe.Expire(context.Background(), key, expireTime)
  523. _, err := pipe.Exec(context.Background())
  524. return err
  525. }
  526. func CacheGetPublicMCP(mcpID string) (*PublicMCPCache, error) {
  527. if !common.RedisEnabled {
  528. publicMCP, err := GetPublicMCPByID(mcpID)
  529. if err != nil {
  530. return nil, err
  531. }
  532. return publicMCP.ToPublicMCPCache(), nil
  533. }
  534. cacheKey := common.RedisKeyf(PublicMCPCacheKey, mcpID)
  535. publicMCPCache := &PublicMCPCache{}
  536. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(publicMCPCache)
  537. if err == nil && publicMCPCache.ID != "" {
  538. return publicMCPCache, nil
  539. } else if err != nil && !errors.Is(err, redis.Nil) {
  540. log.Errorf("get public mcp (%s) from redis error: %s", mcpID, err.Error())
  541. }
  542. publicMCP, err := GetPublicMCPByID(mcpID)
  543. if err != nil {
  544. return nil, err
  545. }
  546. pmc := publicMCP.ToPublicMCPCache()
  547. if err := CacheSetPublicMCP(pmc); err != nil {
  548. log.Error("redis set public mcp error: " + err.Error())
  549. }
  550. return pmc, nil
  551. }
  552. var updatePublicMCPStatusScript = redis.NewScript(`
  553. if redis.call("HExists", KEYS[1], "s") then
  554. redis.call("HSet", KEYS[1], "s", ARGV[1])
  555. end
  556. return redis.status_reply("ok")
  557. `)
  558. func CacheUpdatePublicMCPStatus(mcpID string, status PublicMCPStatus) error {
  559. if !common.RedisEnabled {
  560. return nil
  561. }
  562. return updatePublicMCPStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(PublicMCPCacheKey, mcpID)}, status).
  563. Err()
  564. }
  565. const (
  566. PublicMCPReusingParamCacheKey = "public_mcp_param:%s:%s" // mcp_id:group_id
  567. )
  568. type PublicMCPReusingParamCache struct {
  569. MCPID string `json:"mcp_id" redis:"m"`
  570. GroupID string `json:"group_id" redis:"g"`
  571. Params redisMap[string, string] `json:"params" redis:"p"`
  572. }
  573. func (p *PublicMCPReusingParam) ToPublicMCPReusingParamCache() PublicMCPReusingParamCache {
  574. return PublicMCPReusingParamCache{
  575. MCPID: p.MCPID,
  576. GroupID: p.GroupID,
  577. Params: p.Params,
  578. }
  579. }
  580. func CacheDeletePublicMCPReusingParam(mcpID, groupID string) error {
  581. if !common.RedisEnabled {
  582. return nil
  583. }
  584. return common.RDB.Del(context.Background(), common.RedisKeyf(PublicMCPReusingParamCacheKey, mcpID, groupID)).
  585. Err()
  586. }
  587. func CacheSetPublicMCPReusingParam(param PublicMCPReusingParamCache) error {
  588. if !common.RedisEnabled {
  589. return nil
  590. }
  591. key := common.RedisKeyf(PublicMCPReusingParamCacheKey, param.MCPID, param.GroupID)
  592. pipe := common.RDB.Pipeline()
  593. pipe.HSet(context.Background(), key, param)
  594. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  595. pipe.Expire(context.Background(), key, expireTime)
  596. _, err := pipe.Exec(context.Background())
  597. return err
  598. }
  599. func CacheGetPublicMCPReusingParam(mcpID, groupID string) (PublicMCPReusingParamCache, error) {
  600. if groupID == "" {
  601. return PublicMCPReusingParamCache{
  602. MCPID: mcpID,
  603. GroupID: groupID,
  604. Params: make(map[string]string),
  605. }, nil
  606. }
  607. if !common.RedisEnabled {
  608. param, err := GetPublicMCPReusingParam(mcpID, groupID)
  609. if err != nil {
  610. return PublicMCPReusingParamCache{}, err
  611. }
  612. return param.ToPublicMCPReusingParamCache(), nil
  613. }
  614. cacheKey := common.RedisKeyf(PublicMCPReusingParamCacheKey, mcpID, groupID)
  615. paramCache := PublicMCPReusingParamCache{}
  616. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(&paramCache)
  617. if err == nil && paramCache.MCPID != "" {
  618. return paramCache, nil
  619. } else if err != nil && !errors.Is(err, redis.Nil) {
  620. log.Errorf("get public mcp reusing param (%s:%s) from redis error: %s", mcpID, groupID, err.Error())
  621. }
  622. param, err := GetPublicMCPReusingParam(mcpID, groupID)
  623. if err != nil {
  624. return PublicMCPReusingParamCache{}, err
  625. }
  626. prc := param.ToPublicMCPReusingParamCache()
  627. if err := CacheSetPublicMCPReusingParam(prc); err != nil {
  628. log.Error("redis set public mcp reusing param error: " + err.Error())
  629. }
  630. return prc, nil
  631. }
  632. const (
  633. StoreCacheKey = "storev2:%s:%d:%s" // store_id
  634. )
  635. type StoreCache struct {
  636. ID string `json:"id" redis:"i"`
  637. GroupID string `json:"group_id" redis:"g"`
  638. TokenID int `json:"token_id" redis:"t"`
  639. ChannelID int `json:"channel_id" redis:"c"`
  640. Model string `json:"model" redis:"m"`
  641. ExpiresAt time.Time `json:"expires_at" redis:"e"`
  642. }
  643. func (s *StoreV2) ToStoreCache() *StoreCache {
  644. return &StoreCache{
  645. ID: s.ID,
  646. GroupID: s.GroupID,
  647. TokenID: s.TokenID,
  648. ChannelID: s.ChannelID,
  649. Model: s.Model,
  650. ExpiresAt: s.ExpiresAt,
  651. }
  652. }
  653. func CacheSetStore(store *StoreCache) error {
  654. if !common.RedisEnabled {
  655. return nil
  656. }
  657. key := common.RedisKeyf(StoreCacheKey, store.GroupID, store.TokenID, store.ID)
  658. pipe := common.RDB.Pipeline()
  659. pipe.HSet(context.Background(), key, store)
  660. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  661. pipe.Expire(context.Background(), key, expireTime)
  662. _, err := pipe.Exec(context.Background())
  663. return err
  664. }
  665. func CacheGetStore(group string, tokenID int, id string) (*StoreCache, error) {
  666. if !common.RedisEnabled {
  667. store, err := GetStore(group, tokenID, id)
  668. if err != nil {
  669. return nil, err
  670. }
  671. return store.ToStoreCache(), nil
  672. }
  673. cacheKey := common.RedisKeyf(StoreCacheKey, group, tokenID, id)
  674. storeCache := &StoreCache{}
  675. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(storeCache)
  676. if err == nil && storeCache.ID != "" {
  677. return storeCache, nil
  678. }
  679. store, err := GetStore(group, tokenID, id)
  680. if err != nil {
  681. return nil, err
  682. }
  683. sc := store.ToStoreCache()
  684. if err := CacheSetStore(sc); err != nil {
  685. log.Error("redis set store error: " + err.Error())
  686. }
  687. return sc, nil
  688. }
  689. type ModelConfigCache interface {
  690. GetModelConfig(model string) (ModelConfig, bool)
  691. }
  692. // read-only cache
  693. //
  694. type ModelCaches struct {
  695. ModelConfig ModelConfigCache
  696. // map[set][]model
  697. EnabledModelsBySet map[string][]string
  698. // map[set][]modelconfig
  699. EnabledModelConfigsBySet map[string][]ModelConfig
  700. // map[model]modelconfig
  701. EnabledModelConfigsMap map[string]ModelConfig
  702. // map[set]map[model][]channel
  703. EnabledModel2ChannelsBySet map[string]map[string][]*Channel
  704. // map[set]map[model][]channel
  705. DisabledModel2ChannelsBySet map[string]map[string][]*Channel
  706. }
  707. var modelCaches atomic.Pointer[ModelCaches]
  708. func init() {
  709. modelCaches.Store(new(ModelCaches))
  710. }
  711. func LoadModelCaches() *ModelCaches {
  712. return modelCaches.Load()
  713. }
  714. // InitModelConfigAndChannelCache initializes the channel cache from database
  715. func InitModelConfigAndChannelCache() error {
  716. modelConfig, err := initializeModelConfigCache()
  717. if err != nil {
  718. return err
  719. }
  720. // Load enabled channels from database
  721. enabledChannels, err := LoadEnabledChannels()
  722. if err != nil {
  723. return err
  724. }
  725. // Build model to channels map by set
  726. enabledModel2ChannelsBySet := buildModelToChannelsBySetMap(enabledChannels)
  727. // Sort channels by priority within each set
  728. sortChannelsByPriorityBySet(enabledModel2ChannelsBySet)
  729. // Build enabled models and configs by set
  730. enabledModelsBySet, enabledModelConfigsBySet, enabledModelConfigsMap := buildEnabledModelsBySet(
  731. enabledModel2ChannelsBySet,
  732. modelConfig,
  733. )
  734. // Load disabled channels
  735. disabledChannels, err := LoadDisabledChannels()
  736. if err != nil {
  737. return err
  738. }
  739. // Build disabled model to channels map by set
  740. disabledModel2ChannelsBySet := buildModelToChannelsBySetMap(disabledChannels)
  741. // Update global cache atomically
  742. modelCaches.Store(&ModelCaches{
  743. ModelConfig: modelConfig,
  744. EnabledModelsBySet: enabledModelsBySet,
  745. EnabledModelConfigsBySet: enabledModelConfigsBySet,
  746. EnabledModelConfigsMap: enabledModelConfigsMap,
  747. EnabledModel2ChannelsBySet: enabledModel2ChannelsBySet,
  748. DisabledModel2ChannelsBySet: disabledModel2ChannelsBySet,
  749. })
  750. return nil
  751. }
  752. func LoadEnabledChannels() ([]*Channel, error) {
  753. var channels []*Channel
  754. err := DB.Where("status = ?", ChannelStatusEnabled).Find(&channels).Error
  755. if err != nil {
  756. return nil, err
  757. }
  758. for _, channel := range channels {
  759. initializeChannelModels(channel)
  760. initializeChannelModelMapping(channel)
  761. }
  762. return channels, nil
  763. }
  764. func LoadDisabledChannels() ([]*Channel, error) {
  765. var channels []*Channel
  766. err := DB.Where("status = ?", ChannelStatusDisabled).Find(&channels).Error
  767. if err != nil {
  768. return nil, err
  769. }
  770. for _, channel := range channels {
  771. initializeChannelModels(channel)
  772. initializeChannelModelMapping(channel)
  773. }
  774. return channels, nil
  775. }
  776. func LoadChannels() ([]*Channel, error) {
  777. var channels []*Channel
  778. err := DB.Find(&channels).Error
  779. if err != nil {
  780. return nil, err
  781. }
  782. for _, channel := range channels {
  783. initializeChannelModels(channel)
  784. initializeChannelModelMapping(channel)
  785. }
  786. return channels, nil
  787. }
  788. func LoadChannelByID(id int) (*Channel, error) {
  789. var channel Channel
  790. err := DB.First(&channel, id).Error
  791. if err != nil {
  792. return nil, err
  793. }
  794. initializeChannelModels(&channel)
  795. initializeChannelModelMapping(&channel)
  796. return &channel, nil
  797. }
  798. var _ ModelConfigCache = (*modelConfigMapCache)(nil)
  799. type modelConfigMapCache struct {
  800. modelConfigMap map[string]ModelConfig
  801. }
  802. func (m *modelConfigMapCache) GetModelConfig(model string) (ModelConfig, bool) {
  803. config, ok := m.modelConfigMap[model]
  804. return config, ok
  805. }
  806. var _ ModelConfigCache = (*disabledModelConfigCache)(nil)
  807. type disabledModelConfigCache struct {
  808. modelConfigs ModelConfigCache
  809. }
  810. func (d *disabledModelConfigCache) GetModelConfig(model string) (ModelConfig, bool) {
  811. if config, ok := d.modelConfigs.GetModelConfig(model); ok {
  812. return config, true
  813. }
  814. return NewDefaultModelConfig(model), true
  815. }
  816. func initializeModelConfigCache() (ModelConfigCache, error) {
  817. modelConfigs, err := GetAllModelConfigs()
  818. if err != nil {
  819. return nil, err
  820. }
  821. newModelConfigMap := make(map[string]ModelConfig)
  822. for _, modelConfig := range modelConfigs {
  823. newModelConfigMap[modelConfig.Model] = modelConfig
  824. }
  825. configs := &modelConfigMapCache{modelConfigMap: newModelConfigMap}
  826. if config.DisableModelConfig {
  827. return &disabledModelConfigCache{modelConfigs: configs}, nil
  828. }
  829. return configs, nil
  830. }
  831. func initializeChannelModels(channel *Channel) {
  832. if len(channel.Models) == 0 {
  833. channel.Models = config.GetDefaultChannelModels()[int(channel.Type)]
  834. return
  835. }
  836. findedModels, missingModels, err := GetModelConfigWithModels(channel.Models)
  837. if err != nil {
  838. return
  839. }
  840. if len(missingModels) > 0 {
  841. slices.Sort(missingModels)
  842. log.Errorf("model config not found: %v", missingModels)
  843. }
  844. slices.Sort(findedModels)
  845. channel.Models = findedModels
  846. }
  847. func initializeChannelModelMapping(channel *Channel) {
  848. if len(channel.ModelMapping) == 0 {
  849. channel.ModelMapping = config.GetDefaultChannelModelMapping()[int(channel.Type)]
  850. }
  851. }
  852. func buildModelToChannelsBySetMap(channels []*Channel) map[string]map[string][]*Channel {
  853. modelMapBySet := make(map[string]map[string][]*Channel)
  854. for _, channel := range channels {
  855. sets := channel.GetSets()
  856. for _, set := range sets {
  857. if _, ok := modelMapBySet[set]; !ok {
  858. modelMapBySet[set] = make(map[string][]*Channel)
  859. }
  860. for _, model := range channel.Models {
  861. modelMapBySet[set][model] = append(modelMapBySet[set][model], channel)
  862. }
  863. }
  864. }
  865. return modelMapBySet
  866. }
  867. func sortChannelsByPriorityBySet(modelMapBySet map[string]map[string][]*Channel) {
  868. for _, modelMap := range modelMapBySet {
  869. for _, channels := range modelMap {
  870. sort.Slice(channels, func(i, j int) bool {
  871. return channels[i].GetPriority() > channels[j].GetPriority()
  872. })
  873. }
  874. }
  875. }
  876. func buildEnabledModelsBySet(
  877. modelMapBySet map[string]map[string][]*Channel,
  878. modelConfigCache ModelConfigCache,
  879. ) (
  880. map[string][]string,
  881. map[string][]ModelConfig,
  882. map[string]ModelConfig,
  883. ) {
  884. modelsBySet := make(map[string][]string)
  885. modelConfigsBySet := make(map[string][]ModelConfig)
  886. modelConfigsMap := make(map[string]ModelConfig)
  887. for set, modelMap := range modelMapBySet {
  888. models := make([]string, 0)
  889. configs := make([]ModelConfig, 0)
  890. appended := make(map[string]struct{})
  891. for model := range modelMap {
  892. if _, ok := appended[model]; ok {
  893. continue
  894. }
  895. if config, ok := modelConfigCache.GetModelConfig(model); ok {
  896. models = append(models, model)
  897. configs = append(configs, config)
  898. appended[model] = struct{}{}
  899. modelConfigsMap[model] = config
  900. }
  901. }
  902. slices.Sort(models)
  903. slices.SortStableFunc(configs, SortModelConfigsFunc)
  904. modelsBySet[set] = models
  905. modelConfigsBySet[set] = configs
  906. }
  907. return modelsBySet, modelConfigsBySet, modelConfigsMap
  908. }
  909. func SortModelConfigsFunc(i, j ModelConfig) int {
  910. if i.Owner != j.Owner {
  911. if natural.Less(string(i.Owner), string(j.Owner)) {
  912. return -1
  913. }
  914. return 1
  915. }
  916. if i.Type != j.Type {
  917. if i.Type < j.Type {
  918. return -1
  919. }
  920. return 1
  921. }
  922. if i.Model == j.Model {
  923. return 0
  924. }
  925. if natural.Less(i.Model, j.Model) {
  926. return -1
  927. }
  928. return 1
  929. }
  930. func SyncModelConfigAndChannelCache(
  931. ctx context.Context,
  932. wg *sync.WaitGroup,
  933. frequency time.Duration,
  934. ) {
  935. defer wg.Done()
  936. ticker := time.NewTicker(frequency)
  937. defer ticker.Stop()
  938. for {
  939. select {
  940. case <-ctx.Done():
  941. return
  942. case <-ticker.C:
  943. err := InitModelConfigAndChannelCache()
  944. if err != nil {
  945. notify.ErrorThrottle(
  946. "syncModelChannel",
  947. time.Minute*5,
  948. "failed to sync channels",
  949. err.Error(),
  950. )
  951. }
  952. }
  953. }
  954. }