cache.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134
  1. package model
  2. import (
  3. "context"
  4. "encoding"
  5. "errors"
  6. "math/rand/v2"
  7. "slices"
  8. "sort"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "github.com/bytedance/sonic"
  13. "github.com/labring/aiproxy/core/common"
  14. "github.com/labring/aiproxy/core/common/config"
  15. "github.com/labring/aiproxy/core/common/conv"
  16. "github.com/labring/aiproxy/core/common/notify"
  17. "github.com/maruel/natural"
  18. "github.com/redis/go-redis/v9"
  19. log "github.com/sirupsen/logrus"
  20. )
  21. const (
  22. SyncFrequency = time.Minute * 3
  23. TokenCacheKey = "token:%s"
  24. GroupCacheKey = "group:%s"
  25. GroupModelTPMKey = "group:%s:model_tpm"
  26. )
  27. var (
  28. _ encoding.BinaryMarshaler = (*redisStringSlice)(nil)
  29. _ redis.Scanner = (*redisStringSlice)(nil)
  30. )
  31. type redisStringSlice []string
  32. func (r *redisStringSlice) ScanRedis(value string) error {
  33. return sonic.Unmarshal(conv.StringToBytes(value), r)
  34. }
  35. func (r redisStringSlice) MarshalBinary() ([]byte, error) {
  36. return sonic.Marshal(r)
  37. }
  38. type redisTime time.Time
  39. var (
  40. _ redis.Scanner = (*redisTime)(nil)
  41. _ encoding.BinaryMarshaler = (*redisTime)(nil)
  42. )
  43. func (t *redisTime) ScanRedis(value string) error {
  44. return (*time.Time)(t).UnmarshalBinary(conv.StringToBytes(value))
  45. }
  46. func (t redisTime) MarshalBinary() ([]byte, error) {
  47. return time.Time(t).MarshalBinary()
  48. }
  49. type TokenCache struct {
  50. ExpiredAt redisTime `json:"expired_at" redis:"e"`
  51. Group string `json:"group" redis:"g"`
  52. Key string `json:"-" redis:"-"`
  53. Name string `json:"name" redis:"n"`
  54. Subnets redisStringSlice `json:"subnets" redis:"s"`
  55. Models redisStringSlice `json:"models" redis:"m"`
  56. ID int `json:"id" redis:"i"`
  57. Status int `json:"status" redis:"st"`
  58. Quota float64 `json:"quota" redis:"q"`
  59. UsedAmount float64 `json:"used_amount" redis:"u"`
  60. availableSets []string
  61. modelsBySet map[string][]string
  62. }
  63. func (t *TokenCache) SetAvailableSets(availableSets []string) {
  64. t.availableSets = availableSets
  65. }
  66. func (t *TokenCache) SetModelsBySet(modelsBySet map[string][]string) {
  67. t.modelsBySet = modelsBySet
  68. }
  69. func (t *TokenCache) ContainsModel(model string) bool {
  70. if len(t.Models) != 0 {
  71. if !slices.Contains(t.Models, model) {
  72. return false
  73. }
  74. }
  75. return containsModel(model, t.availableSets, t.modelsBySet)
  76. }
  // containsModel reports whether model is listed under at least one of sets
// in modelsBySet. Sets missing from the map contribute no models.
func containsModel(model string, sets []string, modelsBySet map[string][]string) bool {
	return slices.ContainsFunc(sets, func(set string) bool {
		return slices.Contains(modelsBySet[set], model)
	})
}
  85. func (t *TokenCache) Range(fn func(model string) bool) {
  86. ranged := make(map[string]struct{})
  87. if len(t.Models) != 0 {
  88. for _, model := range t.Models {
  89. if _, ok := ranged[model]; !ok && containsModel(model, t.availableSets, t.modelsBySet) {
  90. if !fn(model) {
  91. return
  92. }
  93. }
  94. ranged[model] = struct{}{}
  95. }
  96. return
  97. }
  98. for _, set := range t.availableSets {
  99. for _, model := range t.modelsBySet[set] {
  100. if _, ok := ranged[model]; !ok {
  101. if !fn(model) {
  102. return
  103. }
  104. }
  105. ranged[model] = struct{}{}
  106. }
  107. }
  108. }
  109. func (t *Token) ToTokenCache() *TokenCache {
  110. return &TokenCache{
  111. ID: t.ID,
  112. Group: t.GroupID,
  113. Key: t.Key,
  114. Name: t.Name.String(),
  115. Models: t.Models,
  116. Subnets: t.Subnets,
  117. Status: t.Status,
  118. ExpiredAt: redisTime(t.ExpiredAt),
  119. Quota: t.Quota,
  120. UsedAmount: t.UsedAmount,
  121. }
  122. }
  123. func CacheDeleteToken(key string) error {
  124. if !common.RedisEnabled {
  125. return nil
  126. }
  127. return common.RDB.Del(context.Background(), common.RedisKeyf(TokenCacheKey, key)).Err()
  128. }
  129. func CacheSetToken(token *TokenCache) error {
  130. if !common.RedisEnabled {
  131. return nil
  132. }
  133. key := common.RedisKeyf(TokenCacheKey, token.Key)
  134. pipe := common.RDB.Pipeline()
  135. pipe.HSet(context.Background(), key, token)
  136. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  137. pipe.Expire(context.Background(), key, expireTime)
  138. _, err := pipe.Exec(context.Background())
  139. return err
  140. }
  141. func CacheGetTokenByKey(key string) (*TokenCache, error) {
  142. if !common.RedisEnabled {
  143. token, err := GetTokenByKey(key)
  144. if err != nil {
  145. return nil, err
  146. }
  147. return token.ToTokenCache(), nil
  148. }
  149. cacheKey := common.RedisKeyf(TokenCacheKey, key)
  150. tokenCache := &TokenCache{}
  151. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(tokenCache)
  152. if err == nil && tokenCache.ID != 0 {
  153. tokenCache.Key = key
  154. return tokenCache, nil
  155. } else if err != nil && !errors.Is(err, redis.Nil) {
  156. log.Errorf("get token (%s) from redis error: %s", key, err.Error())
  157. }
  158. token, err := GetTokenByKey(key)
  159. if err != nil {
  160. return nil, err
  161. }
  162. tc := token.ToTokenCache()
  163. if err := CacheSetToken(tc); err != nil {
  164. log.Error("redis set token error: " + err.Error())
  165. }
  166. return tc, nil
  167. }
  168. var updateTokenUsedAmountOnlyIncreaseScript = redis.NewScript(`
  169. local used_amount = redis.call("HGet", KEYS[1], "ua")
  170. if used_amount == false then
  171. return redis.status_reply("ok")
  172. end
  173. if ARGV[1] < used_amount then
  174. return redis.status_reply("ok")
  175. end
  176. redis.call("HSet", KEYS[1], "ua", ARGV[1])
  177. return redis.status_reply("ok")
  178. `)
  179. func CacheUpdateTokenUsedAmountOnlyIncrease(key string, amount float64) error {
  180. if !common.RedisEnabled {
  181. return nil
  182. }
  183. return updateTokenUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, amount).
  184. Err()
  185. }
  186. var updateTokenNameScript = redis.NewScript(`
  187. if redis.call("HExists", KEYS[1], "n") then
  188. redis.call("HSet", KEYS[1], "n", ARGV[1])
  189. end
  190. return redis.status_reply("ok")
  191. `)
  192. func CacheUpdateTokenName(key, name string) error {
  193. if !common.RedisEnabled {
  194. return nil
  195. }
  196. return updateTokenNameScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, name).
  197. Err()
  198. }
  199. var updateTokenStatusScript = redis.NewScript(`
  200. if redis.call("HExists", KEYS[1], "st") then
  201. redis.call("HSet", KEYS[1], "st", ARGV[1])
  202. end
  203. return redis.status_reply("ok")
  204. `)
  205. func CacheUpdateTokenStatus(key string, status int) error {
  206. if !common.RedisEnabled {
  207. return nil
  208. }
  209. return updateTokenStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(TokenCacheKey, key)}, status).
  210. Err()
  211. }
  212. type redisMap[K comparable, V any] map[K]V
  213. var (
  214. _ redis.Scanner = (*redisMap[string, any])(nil)
  215. _ encoding.BinaryMarshaler = (*redisMap[string, any])(nil)
  216. )
  217. func (r *redisMap[K, V]) ScanRedis(value string) error {
  218. return sonic.UnmarshalString(value, r)
  219. }
  220. func (r redisMap[K, V]) MarshalBinary() ([]byte, error) {
  221. return sonic.Marshal(r)
  222. }
  223. type (
  224. redisGroupModelConfigMap = redisMap[string, GroupModelConfig]
  225. )
  226. type GroupCache struct {
  227. ID string `json:"-" redis:"-"`
  228. Status int `json:"status" redis:"st"`
  229. UsedAmount float64 `json:"used_amount" redis:"ua"`
  230. RPMRatio float64 `json:"rpm_ratio" redis:"rpm_r"`
  231. TPMRatio float64 `json:"tpm_ratio" redis:"tpm_r"`
  232. AvailableSets redisStringSlice `json:"available_sets" redis:"ass"`
  233. ModelConfigs redisGroupModelConfigMap `json:"model_configs" redis:"mc"`
  234. BalanceAlertEnabled bool `json:"balance_alert_enabled" redis:"bae"`
  235. BalanceAlertThreshold float64 `json:"balance_alert_threshold" redis:"bat"`
  236. }
  237. func (g *GroupCache) GetAvailableSets() []string {
  238. if len(g.AvailableSets) == 0 {
  239. return []string{ChannelDefaultSet}
  240. }
  241. return g.AvailableSets
  242. }
  243. func (g *Group) ToGroupCache() *GroupCache {
  244. modelConfigs := make(redisGroupModelConfigMap, len(g.GroupModelConfigs))
  245. for _, modelConfig := range g.GroupModelConfigs {
  246. modelConfigs[modelConfig.Model] = modelConfig
  247. }
  248. return &GroupCache{
  249. ID: g.ID,
  250. Status: g.Status,
  251. UsedAmount: g.UsedAmount,
  252. RPMRatio: g.RPMRatio,
  253. TPMRatio: g.TPMRatio,
  254. AvailableSets: g.AvailableSets,
  255. ModelConfigs: modelConfigs,
  256. BalanceAlertEnabled: g.BalanceAlertEnabled,
  257. BalanceAlertThreshold: g.BalanceAlertThreshold,
  258. }
  259. }
  260. func CacheDeleteGroup(id string) error {
  261. if !common.RedisEnabled {
  262. return nil
  263. }
  264. return common.RDB.Del(context.Background(), common.RedisKeyf(GroupCacheKey, id)).Err()
  265. }
  266. var updateGroupRPMRatioScript = redis.NewScript(`
  267. if redis.call("HExists", KEYS[1], "rpm_r") then
  268. redis.call("HSet", KEYS[1], "rpm_r", ARGV[1])
  269. end
  270. return redis.status_reply("ok")
  271. `)
  272. func CacheUpdateGroupRPMRatio(id string, rpmRatio float64) error {
  273. if !common.RedisEnabled {
  274. return nil
  275. }
  276. return updateGroupRPMRatioScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, rpmRatio).
  277. Err()
  278. }
  279. var updateGroupTPMRatioScript = redis.NewScript(`
  280. if redis.call("HExists", KEYS[1], "tpm_r") then
  281. redis.call("HSet", KEYS[1], "tpm_r", ARGV[1])
  282. end
  283. return redis.status_reply("ok")
  284. `)
  285. func CacheUpdateGroupTPMRatio(id string, tpmRatio float64) error {
  286. if !common.RedisEnabled {
  287. return nil
  288. }
  289. return updateGroupTPMRatioScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, tpmRatio).
  290. Err()
  291. }
  292. var updateGroupStatusScript = redis.NewScript(`
  293. if redis.call("HExists", KEYS[1], "st") then
  294. redis.call("HSet", KEYS[1], "st", ARGV[1])
  295. end
  296. return redis.status_reply("ok")
  297. `)
  298. func CacheUpdateGroupStatus(id string, status int) error {
  299. if !common.RedisEnabled {
  300. return nil
  301. }
  302. return updateGroupStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, status).
  303. Err()
  304. }
  305. func CacheSetGroup(group *GroupCache) error {
  306. if !common.RedisEnabled {
  307. return nil
  308. }
  309. key := common.RedisKeyf(GroupCacheKey, group.ID)
  310. pipe := common.RDB.Pipeline()
  311. pipe.HSet(context.Background(), key, group)
  312. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  313. pipe.Expire(context.Background(), key, expireTime)
  314. _, err := pipe.Exec(context.Background())
  315. return err
  316. }
  317. func CacheGetGroup(id string) (*GroupCache, error) {
  318. if !common.RedisEnabled {
  319. group, err := GetGroupByID(id, true)
  320. if err != nil {
  321. return nil, err
  322. }
  323. return group.ToGroupCache(), nil
  324. }
  325. cacheKey := common.RedisKeyf(GroupCacheKey, id)
  326. groupCache := &GroupCache{}
  327. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(groupCache)
  328. if err == nil && groupCache.Status != 0 {
  329. groupCache.ID = id
  330. return groupCache, nil
  331. } else if err != nil && !errors.Is(err, redis.Nil) {
  332. log.Errorf("get group (%s) from redis error: %s", id, err.Error())
  333. }
  334. group, err := GetGroupByID(id, true)
  335. if err != nil {
  336. return nil, err
  337. }
  338. gc := group.ToGroupCache()
  339. if err := CacheSetGroup(gc); err != nil {
  340. log.Error("redis set group error: " + err.Error())
  341. }
  342. return gc, nil
  343. }
  344. var updateGroupUsedAmountOnlyIncreaseScript = redis.NewScript(`
  345. local used_amount = redis.call("HGet", KEYS[1], "ua")
  346. if used_amount == false then
  347. return redis.status_reply("ok")
  348. end
  349. if ARGV[1] < used_amount then
  350. return redis.status_reply("ok")
  351. end
  352. redis.call("HSet", KEYS[1], "ua", ARGV[1])
  353. return redis.status_reply("ok")
  354. `)
  355. func CacheUpdateGroupUsedAmountOnlyIncrease(id string, amount float64) error {
  356. if !common.RedisEnabled {
  357. return nil
  358. }
  359. return updateGroupUsedAmountOnlyIncreaseScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupCacheKey, id)}, amount).
  360. Err()
  361. }
  362. type GroupMCPCache struct {
  363. ID string `json:"id" redis:"i"`
  364. GroupID string `json:"group_id" redis:"g"`
  365. Status GroupMCPStatus `json:"status" redis:"s"`
  366. Type GroupMCPType `json:"type" redis:"t"`
  367. ProxyConfig *GroupMCPProxyConfig `json:"proxy_config" redis:"pc"`
  368. OpenAPIConfig *MCPOpenAPIConfig `json:"openapi_config" redis:"oc"`
  369. }
  370. func (g *GroupMCP) ToGroupMCPCache() *GroupMCPCache {
  371. return &GroupMCPCache{
  372. ID: g.ID,
  373. GroupID: g.GroupID,
  374. Status: g.Status,
  375. Type: g.Type,
  376. ProxyConfig: g.ProxyConfig,
  377. OpenAPIConfig: g.OpenAPIConfig,
  378. }
  379. }
  380. const (
  381. GroupMCPCacheKey = "group_mcp:%s:%s" // group_id:mcp_id
  382. )
  383. func CacheDeleteGroupMCP(groupID, mcpID string) error {
  384. if !common.RedisEnabled {
  385. return nil
  386. }
  387. return common.RDB.Del(context.Background(), common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)).
  388. Err()
  389. }
  390. func CacheSetGroupMCP(groupMCP *GroupMCPCache) error {
  391. if !common.RedisEnabled {
  392. return nil
  393. }
  394. key := common.RedisKeyf(GroupMCPCacheKey, groupMCP.GroupID, groupMCP.ID)
  395. pipe := common.RDB.Pipeline()
  396. pipe.HSet(context.Background(), key, groupMCP)
  397. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  398. pipe.Expire(context.Background(), key, expireTime)
  399. _, err := pipe.Exec(context.Background())
  400. return err
  401. }
  402. func CacheGetGroupMCP(groupID, mcpID string) (*GroupMCPCache, error) {
  403. if !common.RedisEnabled {
  404. groupMCP, err := GetGroupMCPByID(mcpID, groupID)
  405. if err != nil {
  406. return nil, err
  407. }
  408. return groupMCP.ToGroupMCPCache(), nil
  409. }
  410. cacheKey := common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)
  411. groupMCPCache := &GroupMCPCache{}
  412. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(groupMCPCache)
  413. if err == nil && groupMCPCache.ID != "" {
  414. return groupMCPCache, nil
  415. } else if err != nil && !errors.Is(err, redis.Nil) {
  416. log.Errorf("get group mcp (%s:%s) from redis error: %s", groupID, mcpID, err.Error())
  417. }
  418. groupMCP, err := GetGroupMCPByID(mcpID, groupID)
  419. if err != nil {
  420. return nil, err
  421. }
  422. gmc := groupMCP.ToGroupMCPCache()
  423. if err := CacheSetGroupMCP(gmc); err != nil {
  424. log.Error("redis set group mcp error: " + err.Error())
  425. }
  426. return gmc, nil
  427. }
  428. var updateGroupMCPStatusScript = redis.NewScript(`
  429. if redis.call("HExists", KEYS[1], "s") then
  430. redis.call("HSet", KEYS[1], "s", ARGV[1])
  431. end
  432. return redis.status_reply("ok")
  433. `)
  434. func CacheUpdateGroupMCPStatus(groupID, mcpID string, status GroupMCPStatus) error {
  435. if !common.RedisEnabled {
  436. return nil
  437. }
  438. return updateGroupMCPStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(GroupMCPCacheKey, groupID, mcpID)}, status).
  439. Err()
  440. }
  441. type PublicMCPCache struct {
  442. ID string `json:"id" redis:"i"`
  443. Status PublicMCPStatus `json:"status" redis:"s"`
  444. Type PublicMCPType `json:"type" redis:"t"`
  445. Price MCPPrice `json:"price" redis:"p"`
  446. ProxyConfig *PublicMCPProxyConfig `json:"proxy_config" redis:"pc"`
  447. OpenAPIConfig *MCPOpenAPIConfig `json:"openapi_config" redis:"oc"`
  448. EmbedConfig *MCPEmbeddingConfig `json:"embed_config" redis:"ec"`
  449. }
  450. func (p *PublicMCP) ToPublicMCPCache() *PublicMCPCache {
  451. return &PublicMCPCache{
  452. ID: p.ID,
  453. Status: p.Status,
  454. Type: p.Type,
  455. Price: p.Price,
  456. ProxyConfig: p.ProxyConfig,
  457. OpenAPIConfig: p.OpenAPIConfig,
  458. EmbedConfig: p.EmbedConfig,
  459. }
  460. }
  461. const (
  462. PublicMCPCacheKey = "public_mcp:%s" // mcp_id
  463. )
  464. func CacheDeletePublicMCP(mcpID string) error {
  465. if !common.RedisEnabled {
  466. return nil
  467. }
  468. return common.RDB.Del(context.Background(), common.RedisKeyf(PublicMCPCacheKey, mcpID)).Err()
  469. }
  470. func CacheSetPublicMCP(publicMCP *PublicMCPCache) error {
  471. if !common.RedisEnabled {
  472. return nil
  473. }
  474. key := common.RedisKeyf(PublicMCPCacheKey, publicMCP.ID)
  475. pipe := common.RDB.Pipeline()
  476. pipe.HSet(context.Background(), key, publicMCP)
  477. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  478. pipe.Expire(context.Background(), key, expireTime)
  479. _, err := pipe.Exec(context.Background())
  480. return err
  481. }
  482. func CacheGetPublicMCP(mcpID string) (*PublicMCPCache, error) {
  483. if !common.RedisEnabled {
  484. publicMCP, err := GetPublicMCPByID(mcpID)
  485. if err != nil {
  486. return nil, err
  487. }
  488. return publicMCP.ToPublicMCPCache(), nil
  489. }
  490. cacheKey := common.RedisKeyf(PublicMCPCacheKey, mcpID)
  491. publicMCPCache := &PublicMCPCache{}
  492. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(publicMCPCache)
  493. if err == nil && publicMCPCache.ID != "" {
  494. return publicMCPCache, nil
  495. } else if err != nil && !errors.Is(err, redis.Nil) {
  496. log.Errorf("get public mcp (%s) from redis error: %s", mcpID, err.Error())
  497. }
  498. publicMCP, err := GetPublicMCPByID(mcpID)
  499. if err != nil {
  500. return nil, err
  501. }
  502. pmc := publicMCP.ToPublicMCPCache()
  503. if err := CacheSetPublicMCP(pmc); err != nil {
  504. log.Error("redis set public mcp error: " + err.Error())
  505. }
  506. return pmc, nil
  507. }
  508. var updatePublicMCPStatusScript = redis.NewScript(`
  509. if redis.call("HExists", KEYS[1], "s") then
  510. redis.call("HSet", KEYS[1], "s", ARGV[1])
  511. end
  512. return redis.status_reply("ok")
  513. `)
  514. func CacheUpdatePublicMCPStatus(mcpID string, status PublicMCPStatus) error {
  515. if !common.RedisEnabled {
  516. return nil
  517. }
  518. return updatePublicMCPStatusScript.Run(context.Background(), common.RDB, []string{common.RedisKeyf(PublicMCPCacheKey, mcpID)}, status).
  519. Err()
  520. }
  521. const (
  522. PublicMCPReusingParamCacheKey = "public_mcp_param:%s:%s" // mcp_id:group_id
  523. )
  524. type PublicMCPReusingParamCache struct {
  525. MCPID string `json:"mcp_id" redis:"m"`
  526. GroupID string `json:"group_id" redis:"g"`
  527. Params redisMap[string, string] `json:"params" redis:"p"`
  528. }
  529. func (p *PublicMCPReusingParam) ToPublicMCPReusingParamCache() PublicMCPReusingParamCache {
  530. return PublicMCPReusingParamCache{
  531. MCPID: p.MCPID,
  532. GroupID: p.GroupID,
  533. Params: p.Params,
  534. }
  535. }
  536. func CacheDeletePublicMCPReusingParam(mcpID, groupID string) error {
  537. if !common.RedisEnabled {
  538. return nil
  539. }
  540. return common.RDB.Del(context.Background(), common.RedisKeyf(PublicMCPReusingParamCacheKey, mcpID, groupID)).
  541. Err()
  542. }
  543. func CacheSetPublicMCPReusingParam(param PublicMCPReusingParamCache) error {
  544. if !common.RedisEnabled {
  545. return nil
  546. }
  547. key := common.RedisKeyf(PublicMCPReusingParamCacheKey, param.MCPID, param.GroupID)
  548. pipe := common.RDB.Pipeline()
  549. pipe.HSet(context.Background(), key, param)
  550. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  551. pipe.Expire(context.Background(), key, expireTime)
  552. _, err := pipe.Exec(context.Background())
  553. return err
  554. }
  555. func CacheGetPublicMCPReusingParam(mcpID, groupID string) (PublicMCPReusingParamCache, error) {
  556. if groupID == "" {
  557. return PublicMCPReusingParamCache{
  558. MCPID: mcpID,
  559. GroupID: groupID,
  560. Params: make(map[string]string),
  561. }, nil
  562. }
  563. if !common.RedisEnabled {
  564. param, err := GetPublicMCPReusingParam(mcpID, groupID)
  565. if err != nil {
  566. return PublicMCPReusingParamCache{}, err
  567. }
  568. return param.ToPublicMCPReusingParamCache(), nil
  569. }
  570. cacheKey := common.RedisKeyf(PublicMCPReusingParamCacheKey, mcpID, groupID)
  571. paramCache := PublicMCPReusingParamCache{}
  572. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(&paramCache)
  573. if err == nil && paramCache.MCPID != "" {
  574. return paramCache, nil
  575. } else if err != nil && !errors.Is(err, redis.Nil) {
  576. log.Errorf("get public mcp reusing param (%s:%s) from redis error: %s", mcpID, groupID, err.Error())
  577. }
  578. param, err := GetPublicMCPReusingParam(mcpID, groupID)
  579. if err != nil {
  580. return PublicMCPReusingParamCache{}, err
  581. }
  582. prc := param.ToPublicMCPReusingParamCache()
  583. if err := CacheSetPublicMCPReusingParam(prc); err != nil {
  584. log.Error("redis set public mcp reusing param error: " + err.Error())
  585. }
  586. return prc, nil
  587. }
  588. const (
  589. StoreCacheKey = "store:%s" // store_id
  590. )
  591. type StoreCache struct {
  592. ID string `json:"id" redis:"i"`
  593. GroupID string `json:"group_id" redis:"g"`
  594. TokenID int `json:"token_id" redis:"t"`
  595. ChannelID int `json:"channel_id" redis:"c"`
  596. Model string `json:"model" redis:"m"`
  597. ExpiresAt time.Time `json:"expires_at" redis:"e"`
  598. }
  599. func (s *Store) ToStoreCache() *StoreCache {
  600. return &StoreCache{
  601. ID: s.ID,
  602. GroupID: s.GroupID,
  603. TokenID: s.TokenID,
  604. ChannelID: s.ChannelID,
  605. Model: s.Model,
  606. ExpiresAt: s.ExpiresAt,
  607. }
  608. }
  609. func CacheSetStore(store *StoreCache) error {
  610. if !common.RedisEnabled {
  611. return nil
  612. }
  613. key := common.RedisKeyf(StoreCacheKey, store.ID)
  614. pipe := common.RDB.Pipeline()
  615. pipe.HSet(context.Background(), key, store)
  616. expireTime := SyncFrequency + time.Duration(rand.Int64N(60)-30)*time.Second
  617. pipe.Expire(context.Background(), key, expireTime)
  618. _, err := pipe.Exec(context.Background())
  619. return err
  620. }
  621. func CacheGetStore(id string) (*StoreCache, error) {
  622. if !common.RedisEnabled {
  623. store, err := GetStore(id)
  624. if err != nil {
  625. return nil, err
  626. }
  627. return store.ToStoreCache(), nil
  628. }
  629. cacheKey := common.RedisKeyf(StoreCacheKey, id)
  630. storeCache := &StoreCache{}
  631. err := common.RDB.HGetAll(context.Background(), cacheKey).Scan(storeCache)
  632. if err == nil && storeCache.ID != "" {
  633. return storeCache, nil
  634. }
  635. store, err := GetStore(id)
  636. if err != nil {
  637. return nil, err
  638. }
  639. sc := store.ToStoreCache()
  640. if err := CacheSetStore(sc); err != nil {
  641. log.Error("redis set store error: " + err.Error())
  642. }
  643. return sc, nil
  644. }
  645. type ModelConfigCache interface {
  646. GetModelConfig(model string) (ModelConfig, bool)
  647. }
  648. // read-only cache
  649. //
  650. type ModelCaches struct {
  651. ModelConfig ModelConfigCache
  652. // map[set][]model
  653. EnabledModelsBySet map[string][]string
  654. // map[set][]modelconfig
  655. EnabledModelConfigsBySet map[string][]ModelConfig
  656. // map[model]modelconfig
  657. EnabledModelConfigsMap map[string]ModelConfig
  658. // map[set]map[model][]channel
  659. EnabledModel2ChannelsBySet map[string]map[string][]*Channel
  660. // map[set]map[model][]channel
  661. DisabledModel2ChannelsBySet map[string]map[string][]*Channel
  662. }
  663. var modelCaches atomic.Pointer[ModelCaches]
  664. func init() {
  665. modelCaches.Store(new(ModelCaches))
  666. }
  667. func LoadModelCaches() *ModelCaches {
  668. return modelCaches.Load()
  669. }
  670. // InitModelConfigAndChannelCache initializes the channel cache from database
  671. func InitModelConfigAndChannelCache() error {
  672. modelConfig, err := initializeModelConfigCache()
  673. if err != nil {
  674. return err
  675. }
  676. // Load enabled channels from database
  677. enabledChannels, err := LoadEnabledChannels()
  678. if err != nil {
  679. return err
  680. }
  681. // Build model to channels map by set
  682. enabledModel2ChannelsBySet := buildModelToChannelsBySetMap(enabledChannels)
  683. // Sort channels by priority within each set
  684. sortChannelsByPriorityBySet(enabledModel2ChannelsBySet)
  685. // Build enabled models and configs by set
  686. enabledModelsBySet, enabledModelConfigsBySet, enabledModelConfigsMap := buildEnabledModelsBySet(
  687. enabledModel2ChannelsBySet,
  688. modelConfig,
  689. )
  690. // Load disabled channels
  691. disabledChannels, err := LoadDisabledChannels()
  692. if err != nil {
  693. return err
  694. }
  695. // Build disabled model to channels map by set
  696. disabledModel2ChannelsBySet := buildModelToChannelsBySetMap(disabledChannels)
  697. // Update global cache atomically
  698. modelCaches.Store(&ModelCaches{
  699. ModelConfig: modelConfig,
  700. EnabledModelsBySet: enabledModelsBySet,
  701. EnabledModelConfigsBySet: enabledModelConfigsBySet,
  702. EnabledModelConfigsMap: enabledModelConfigsMap,
  703. EnabledModel2ChannelsBySet: enabledModel2ChannelsBySet,
  704. DisabledModel2ChannelsBySet: disabledModel2ChannelsBySet,
  705. })
  706. return nil
  707. }
  708. func LoadEnabledChannels() ([]*Channel, error) {
  709. var channels []*Channel
  710. err := DB.Where("status = ?", ChannelStatusEnabled).Find(&channels).Error
  711. if err != nil {
  712. return nil, err
  713. }
  714. for _, channel := range channels {
  715. initializeChannelModels(channel)
  716. initializeChannelModelMapping(channel)
  717. }
  718. return channels, nil
  719. }
  720. func LoadDisabledChannels() ([]*Channel, error) {
  721. var channels []*Channel
  722. err := DB.Where("status = ?", ChannelStatusDisabled).Find(&channels).Error
  723. if err != nil {
  724. return nil, err
  725. }
  726. for _, channel := range channels {
  727. initializeChannelModels(channel)
  728. initializeChannelModelMapping(channel)
  729. }
  730. return channels, nil
  731. }
  732. func LoadChannels() ([]*Channel, error) {
  733. var channels []*Channel
  734. err := DB.Find(&channels).Error
  735. if err != nil {
  736. return nil, err
  737. }
  738. for _, channel := range channels {
  739. initializeChannelModels(channel)
  740. initializeChannelModelMapping(channel)
  741. }
  742. return channels, nil
  743. }
  744. func LoadChannelByID(id int) (*Channel, error) {
  745. var channel Channel
  746. err := DB.First(&channel, id).Error
  747. if err != nil {
  748. return nil, err
  749. }
  750. initializeChannelModels(&channel)
  751. initializeChannelModelMapping(&channel)
  752. return &channel, nil
  753. }
  754. var _ ModelConfigCache = (*modelConfigMapCache)(nil)
  755. type modelConfigMapCache struct {
  756. modelConfigMap map[string]ModelConfig
  757. }
  758. func (m *modelConfigMapCache) GetModelConfig(model string) (ModelConfig, bool) {
  759. config, ok := m.modelConfigMap[model]
  760. return config, ok
  761. }
  762. var _ ModelConfigCache = (*disabledModelConfigCache)(nil)
  763. type disabledModelConfigCache struct {
  764. modelConfigs ModelConfigCache
  765. }
  766. func (d *disabledModelConfigCache) GetModelConfig(model string) (ModelConfig, bool) {
  767. if config, ok := d.modelConfigs.GetModelConfig(model); ok {
  768. return config, true
  769. }
  770. return NewDefaultModelConfig(model), true
  771. }
  772. func initializeModelConfigCache() (ModelConfigCache, error) {
  773. modelConfigs, err := GetAllModelConfigs()
  774. if err != nil {
  775. return nil, err
  776. }
  777. newModelConfigMap := make(map[string]ModelConfig)
  778. for _, modelConfig := range modelConfigs {
  779. newModelConfigMap[modelConfig.Model] = modelConfig
  780. }
  781. configs := &modelConfigMapCache{modelConfigMap: newModelConfigMap}
  782. if config.DisableModelConfig {
  783. return &disabledModelConfigCache{modelConfigs: configs}, nil
  784. }
  785. return configs, nil
  786. }
  787. func initializeChannelModels(channel *Channel) {
  788. if len(channel.Models) == 0 {
  789. channel.Models = config.GetDefaultChannelModels()[int(channel.Type)]
  790. return
  791. }
  792. findedModels, missingModels, err := GetModelConfigWithModels(channel.Models)
  793. if err != nil {
  794. return
  795. }
  796. if len(missingModels) > 0 {
  797. slices.Sort(missingModels)
  798. log.Errorf("model config not found: %v", missingModels)
  799. }
  800. slices.Sort(findedModels)
  801. channel.Models = findedModels
  802. }
  803. func initializeChannelModelMapping(channel *Channel) {
  804. if len(channel.ModelMapping) == 0 {
  805. channel.ModelMapping = config.GetDefaultChannelModelMapping()[int(channel.Type)]
  806. }
  807. }
  808. func buildModelToChannelsBySetMap(channels []*Channel) map[string]map[string][]*Channel {
  809. modelMapBySet := make(map[string]map[string][]*Channel)
  810. for _, channel := range channels {
  811. sets := channel.GetSets()
  812. for _, set := range sets {
  813. if _, ok := modelMapBySet[set]; !ok {
  814. modelMapBySet[set] = make(map[string][]*Channel)
  815. }
  816. for _, model := range channel.Models {
  817. modelMapBySet[set][model] = append(modelMapBySet[set][model], channel)
  818. }
  819. }
  820. }
  821. return modelMapBySet
  822. }
  823. func sortChannelsByPriorityBySet(modelMapBySet map[string]map[string][]*Channel) {
  824. for _, modelMap := range modelMapBySet {
  825. for _, channels := range modelMap {
  826. sort.Slice(channels, func(i, j int) bool {
  827. return channels[i].GetPriority() > channels[j].GetPriority()
  828. })
  829. }
  830. }
  831. }
  832. func buildEnabledModelsBySet(
  833. modelMapBySet map[string]map[string][]*Channel,
  834. modelConfigCache ModelConfigCache,
  835. ) (
  836. map[string][]string,
  837. map[string][]ModelConfig,
  838. map[string]ModelConfig,
  839. ) {
  840. modelsBySet := make(map[string][]string)
  841. modelConfigsBySet := make(map[string][]ModelConfig)
  842. modelConfigsMap := make(map[string]ModelConfig)
  843. for set, modelMap := range modelMapBySet {
  844. models := make([]string, 0)
  845. configs := make([]ModelConfig, 0)
  846. appended := make(map[string]struct{})
  847. for model := range modelMap {
  848. if _, ok := appended[model]; ok {
  849. continue
  850. }
  851. if config, ok := modelConfigCache.GetModelConfig(model); ok {
  852. models = append(models, model)
  853. configs = append(configs, config)
  854. appended[model] = struct{}{}
  855. modelConfigsMap[model] = config
  856. }
  857. }
  858. slices.Sort(models)
  859. slices.SortStableFunc(configs, SortModelConfigsFunc)
  860. modelsBySet[set] = models
  861. modelConfigsBySet[set] = configs
  862. }
  863. return modelsBySet, modelConfigsBySet, modelConfigsMap
  864. }
  865. func SortModelConfigsFunc(i, j ModelConfig) int {
  866. if i.Owner != j.Owner {
  867. if natural.Less(string(i.Owner), string(j.Owner)) {
  868. return -1
  869. }
  870. return 1
  871. }
  872. if i.Type != j.Type {
  873. if i.Type < j.Type {
  874. return -1
  875. }
  876. return 1
  877. }
  878. if i.Model == j.Model {
  879. return 0
  880. }
  881. if natural.Less(i.Model, j.Model) {
  882. return -1
  883. }
  884. return 1
  885. }
  886. func SyncModelConfigAndChannelCache(
  887. ctx context.Context,
  888. wg *sync.WaitGroup,
  889. frequency time.Duration,
  890. ) {
  891. defer wg.Done()
  892. ticker := time.NewTicker(frequency)
  893. defer ticker.Stop()
  894. for {
  895. select {
  896. case <-ctx.Done():
  897. return
  898. case <-ticker.C:
  899. err := InitModelConfigAndChannelCache()
  900. if err != nil {
  901. notify.ErrorThrottle(
  902. "syncModelChannel",
  903. time.Minute,
  904. "failed to sync channels",
  905. err.Error(),
  906. )
  907. }
  908. }
  909. }
  910. }