model.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. package monitor
  2. import (
  3. "context"
  4. "fmt"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/labring/aiproxy/core/common"
  9. "github.com/labring/aiproxy/core/common/config"
  10. "github.com/redis/go-redis/v9"
  11. )
  12. // Redis key prefixes and patterns
  13. const (
  14. bannedKeySuffix = ":banned"
  15. statsKeySuffix = ":stats"
  16. modelTotalStatsSuffix = ":total_stats"
  17. channelKeyPart = ":channel:"
  18. )
  19. func modelKeyPrefix() string {
  20. return common.RedisKey("model:")
  21. }
  22. // Redis scripts
  23. var (
  24. addRequestScript = redis.NewScript(addRequestLuaScript)
  25. getErrorRateScript = redis.NewScript(getErrorRateLuaScript)
  26. clearChannelModelErrorsScript = redis.NewScript(clearChannelModelErrorsLuaScript)
  27. clearChannelAllModelErrorsScript = redis.NewScript(clearChannelAllModelErrorsLuaScript)
  28. clearAllModelErrorsScript = redis.NewScript(clearAllModelErrorsLuaScript)
  29. )
  30. func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
  31. if !common.RedisEnabled {
  32. return memModelMonitor.GetModelsErrorRate(ctx)
  33. }
  34. result := make(map[string]float64)
  35. pattern := modelKeyPrefix() + "*" + modelTotalStatsSuffix
  36. now := time.Now().UnixMilli()
  37. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  38. for iter.Next(ctx) {
  39. key := iter.Val()
  40. model := strings.TrimPrefix(key, modelKeyPrefix())
  41. model = strings.TrimSuffix(model, modelTotalStatsSuffix)
  42. rate, err := getErrorRateScript.Run(
  43. ctx,
  44. common.RDB,
  45. []string{key},
  46. now,
  47. ).Float64()
  48. if err != nil {
  49. return nil, err
  50. }
  51. result[model] = rate
  52. }
  53. if err := iter.Err(); err != nil {
  54. return nil, err
  55. }
  56. return result, nil
  57. }
  58. // AddRequest adds a request record and checks if channel should be banned
  59. // warnErrorRate: threshold for warning (default 30%)
  60. // maxErrorRate: threshold for banning (0 means no banning)
  61. func AddRequest(
  62. ctx context.Context,
  63. model string,
  64. channelID int64,
  65. isError, tryBan bool,
  66. warnErrorRate,
  67. maxErrorRate float64,
  68. ) (beyondThreshold, banExecution bool, err error) {
  69. // Set default warning threshold if not specified
  70. if warnErrorRate <= 0 {
  71. warnErrorRate = config.GetDefaultWarnNotifyErrorRate()
  72. }
  73. if !common.RedisEnabled {
  74. beyondThreshold, banExecution = memModelMonitor.AddRequest(
  75. model,
  76. channelID,
  77. isError,
  78. tryBan,
  79. warnErrorRate,
  80. maxErrorRate,
  81. )
  82. return beyondThreshold, banExecution, nil
  83. }
  84. errorFlag := 0
  85. if isError {
  86. errorFlag = 1
  87. } else {
  88. tryBan = false
  89. }
  90. now := time.Now().UnixMilli()
  91. val, err := addRequestScript.Run(
  92. ctx,
  93. common.RDB,
  94. []string{common.RedisKeyPrefix(), model},
  95. channelID,
  96. errorFlag,
  97. now,
  98. warnErrorRate,
  99. maxErrorRate,
  100. maxErrorRate > 0,
  101. tryBan,
  102. ).Int64()
  103. if err != nil {
  104. return false, false, err
  105. }
  106. return val == 3, val == 1, nil
  107. }
  108. func buildStatsKey(model, channelID string) string {
  109. return fmt.Sprintf(
  110. "%s%s%s%v%s",
  111. modelKeyPrefix(),
  112. model,
  113. channelKeyPart,
  114. channelID,
  115. statsKeySuffix,
  116. )
  117. }
  118. func getModelChannelID(key string) (string, int64, bool) {
  119. content := strings.TrimPrefix(key, modelKeyPrefix())
  120. content = strings.TrimSuffix(content, statsKeySuffix)
  121. model, channelIDStr, ok := strings.Cut(content, channelKeyPart)
  122. if !ok {
  123. return "", 0, false
  124. }
  125. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  126. if err != nil {
  127. return "", 0, false
  128. }
  129. return model, channelID, true
  130. }
  131. // GetChannelModelErrorRates gets error rates for a specific channel
  132. func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
  133. if !common.RedisEnabled {
  134. return memModelMonitor.GetChannelModelErrorRates(ctx, channelID)
  135. }
  136. result := make(map[string]float64)
  137. pattern := buildStatsKey("*", strconv.FormatInt(channelID, 10))
  138. now := time.Now().UnixMilli()
  139. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  140. for iter.Next(ctx) {
  141. key := iter.Val()
  142. model, _, ok := getModelChannelID(key)
  143. if !ok {
  144. continue
  145. }
  146. rate, err := getErrorRateScript.Run(
  147. ctx,
  148. common.RDB,
  149. []string{key},
  150. now,
  151. ).Float64()
  152. if err != nil {
  153. return nil, err
  154. }
  155. result[model] = rate
  156. }
  157. if err := iter.Err(); err != nil {
  158. return nil, err
  159. }
  160. return result, nil
  161. }
  162. func GetModelChannelErrorRate(ctx context.Context, model string) (map[int64]float64, error) {
  163. if !common.RedisEnabled {
  164. return memModelMonitor.GetModelChannelErrorRate(ctx, model)
  165. }
  166. result := make(map[int64]float64)
  167. pattern := buildStatsKey(model, "*")
  168. now := time.Now().UnixMilli()
  169. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  170. for iter.Next(ctx) {
  171. key := iter.Val()
  172. _, channelID, ok := getModelChannelID(key)
  173. if !ok {
  174. continue
  175. }
  176. rate, err := getErrorRateScript.Run(
  177. ctx,
  178. common.RDB,
  179. []string{key},
  180. now,
  181. ).Float64()
  182. if err != nil {
  183. return nil, err
  184. }
  185. result[channelID] = rate
  186. }
  187. if err := iter.Err(); err != nil {
  188. return nil, err
  189. }
  190. return result, nil
  191. }
  192. // GetBannedChannelsWithModel gets banned channels for a specific model
  193. func GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, error) {
  194. if !common.RedisEnabled {
  195. return memModelMonitor.GetBannedChannelsWithModel(ctx, model)
  196. }
  197. result := []int64{}
  198. prefix := modelKeyPrefix() + model + channelKeyPart
  199. pattern := prefix + "*" + bannedKeySuffix
  200. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  201. for iter.Next(ctx) {
  202. key := iter.Val()
  203. channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
  204. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  205. if err != nil {
  206. continue
  207. }
  208. result = append(result, channelID)
  209. }
  210. if err := iter.Err(); err != nil {
  211. return nil, err
  212. }
  213. return result, nil
  214. }
  215. // GetBannedChannelsMapWithModel gets banned channels for a specific model as a map for efficient lookups
  216. func GetBannedChannelsMapWithModel(ctx context.Context, model string) (map[int64]struct{}, error) {
  217. if !common.RedisEnabled {
  218. return memModelMonitor.GetBannedChannelsMapWithModel(ctx, model)
  219. }
  220. result := make(map[int64]struct{})
  221. prefix := modelKeyPrefix() + model + channelKeyPart
  222. pattern := prefix + "*" + bannedKeySuffix
  223. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  224. for iter.Next(ctx) {
  225. key := iter.Val()
  226. channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
  227. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  228. if err != nil {
  229. continue
  230. }
  231. result[channelID] = struct{}{}
  232. }
  233. if err := iter.Err(); err != nil {
  234. return nil, err
  235. }
  236. return result, nil
  237. }
  238. // ClearChannelModelErrors clears errors for a specific channel and model
  239. func ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
  240. if !common.RedisEnabled {
  241. return memModelMonitor.ClearChannelModelErrors(ctx, model, channelID)
  242. }
  243. return clearChannelModelErrorsScript.Run(
  244. ctx,
  245. common.RDB,
  246. []string{common.RedisKeyPrefix(), model},
  247. strconv.Itoa(channelID),
  248. ).Err()
  249. }
  250. // ClearChannelAllModelErrors clears all errors for a specific channel
  251. func ClearChannelAllModelErrors(ctx context.Context, channelID int) error {
  252. if !common.RedisEnabled {
  253. return memModelMonitor.ClearChannelAllModelErrors(ctx, channelID)
  254. }
  255. return clearChannelAllModelErrorsScript.Run(
  256. ctx,
  257. common.RDB,
  258. []string{common.RedisKeyPrefix()},
  259. strconv.Itoa(channelID),
  260. ).Err()
  261. }
  262. // ClearAllModelErrors clears all error records
  263. func ClearAllModelErrors(ctx context.Context) error {
  264. if !common.RedisEnabled {
  265. return memModelMonitor.ClearAllModelErrors(ctx)
  266. }
  267. return clearAllModelErrorsScript.Run(ctx, common.RDB, []string{common.RedisKeyPrefix()}).Err()
  268. }
  269. // GetAllBannedModelChannels gets all banned channels for all models
  270. func GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error) {
  271. if !common.RedisEnabled {
  272. return memModelMonitor.GetAllBannedModelChannels(ctx)
  273. }
  274. result := make(map[string][]int64)
  275. pattern := modelKeyPrefix() + "*" + channelKeyPart + "*" + bannedKeySuffix
  276. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  277. for iter.Next(ctx) {
  278. key := iter.Val()
  279. parts := strings.TrimPrefix(key, modelKeyPrefix())
  280. parts = strings.TrimSuffix(parts, bannedKeySuffix)
  281. model, channelIDStr, ok := strings.Cut(parts, channelKeyPart)
  282. if !ok {
  283. continue
  284. }
  285. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  286. if err != nil {
  287. continue
  288. }
  289. if _, exists := result[model]; !exists {
  290. result[model] = []int64{}
  291. }
  292. result[model] = append(result[model], channelID)
  293. }
  294. if err := iter.Err(); err != nil {
  295. return nil, err
  296. }
  297. return result, nil
  298. }
  299. // GetAllChannelModelErrorRates gets error rates for all channels and models
  300. func GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) {
  301. if !common.RedisEnabled {
  302. return memModelMonitor.GetAllChannelModelErrorRates(ctx)
  303. }
  304. result := make(map[int64]map[string]float64)
  305. pattern := buildStatsKey("*", "*")
  306. now := time.Now().UnixMilli()
  307. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  308. for iter.Next(ctx) {
  309. key := iter.Val()
  310. model, channelID, ok := getModelChannelID(key)
  311. if !ok {
  312. continue
  313. }
  314. rate, err := getErrorRateScript.Run(
  315. ctx,
  316. common.RDB,
  317. []string{key},
  318. now,
  319. ).Float64()
  320. if err != nil {
  321. return nil, err
  322. }
  323. if _, exists := result[channelID]; !exists {
  324. result[channelID] = make(map[string]float64)
  325. }
  326. result[channelID][model] = rate
  327. }
  328. if err := iter.Err(); err != nil {
  329. return nil, err
  330. }
  331. return result, nil
  332. }
  333. // Lua scripts
  334. const (
  335. addRequestLuaScript = `
  336. local prefix = KEYS[1]
  337. local model = KEYS[2]
  338. local channel_id = ARGV[1]
  339. local is_error = tonumber(ARGV[2])
  340. local now_ts = tonumber(ARGV[3])
  341. local warn_error_rate = tonumber(ARGV[4])
  342. local max_error_rate = tonumber(ARGV[5])
  343. local can_ban = tonumber(ARGV[6])
  344. local try_ban = tonumber(ARGV[7])
  345. local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"
  346. local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
  347. local model_stats_key = prefix .. ":model:" .. model .. ":total_stats"
  348. local maxSliceCount = 12
  349. local statsExpiry = maxSliceCount * 10 * 1000
  350. local banExpiry = 5 * 60 * 1000
  351. local current_slice = math.floor(now_ts / 10 / 1000)
  352. local function parse_req_err(value)
  353. if not value then return 0, 0 end
  354. local r, e = value:match("^(%d+):(%d+)$")
  355. return tonumber(r) or 0, tonumber(e) or 0
  356. end
  357. local function update_stats(key)
  358. local req, err = parse_req_err(redis.call("HGET", key, current_slice))
  359. req = req + 1
  360. err = err + (is_error == 1 and 1 or 0)
  361. redis.call("HSET", key, current_slice, req .. ":" .. err)
  362. redis.call("PEXPIRE", key, statsExpiry)
  363. return req, err
  364. end
  365. local function get_clean_req_err(key)
  366. local total_req, total_err = 0, 0
  367. local min_valid_slice = current_slice - maxSliceCount
  368. local all_slices = redis.call("HGETALL", key)
  369. for i = 1, #all_slices, 2 do
  370. local slice = tonumber(all_slices[i])
  371. if slice < min_valid_slice then
  372. redis.call("HDEL", key, all_slices[i])
  373. else
  374. local req, err = parse_req_err(all_slices[i+1])
  375. total_req = total_req + req
  376. total_err = total_err + err
  377. end
  378. end
  379. return total_req, total_err
  380. end
  381. update_stats(stats_key)
  382. update_stats(model_stats_key)
  383. local function check_channel_error()
  384. local already_banned = redis.call("EXISTS", banned_key) == 1
  385. if try_ban == 1 and can_ban == 1 then
  386. if already_banned then
  387. return 2
  388. end
  389. redis.call("SET", banned_key, 1)
  390. redis.call("PEXPIRE", banned_key, banExpiry)
  391. return 1
  392. end
  393. local total_req, total_err = get_clean_req_err(stats_key)
  394. if total_req < 20 then
  395. return 0
  396. end
  397. local error_rate = total_err / total_req
  398. -- Check if we should ban (only if max_error_rate is set and exceeded)
  399. if can_ban == 1 and error_rate >= max_error_rate then
  400. if already_banned then
  401. return 3 -- Beyond threshold but already banned
  402. end
  403. redis.call("SET", banned_key, 1)
  404. redis.call("PEXPIRE", banned_key, banExpiry)
  405. return 1 -- Ban executed
  406. elseif error_rate >= warn_error_rate then
  407. return 3 -- Beyond warning threshold but not banning
  408. else
  409. return 0 -- All good
  410. end
  411. end
  412. return check_channel_error()
  413. `
  414. getErrorRateLuaScript = `
  415. local stats_key = KEYS[1]
  416. local now_ts = tonumber(ARGV[1])
  417. local maxSliceCount = 12
  418. local current_slice = math.floor(now_ts / 10 / 1000)
  419. local function parse_req_err(value)
  420. if not value then return 0, 0 end
  421. local r, e = value:match("^(%d+):(%d+)$")
  422. return tonumber(r) or 0, tonumber(e) or 0
  423. end
  424. local function get_clean_req_err(key)
  425. local total_req, total_err = 0, 0
  426. local min_valid_slice = current_slice - maxSliceCount
  427. local all_slices = redis.call("HGETALL", key)
  428. for i = 1, #all_slices, 2 do
  429. local slice = tonumber(all_slices[i])
  430. if slice < min_valid_slice then
  431. redis.call("HDEL", key, all_slices[i])
  432. else
  433. local req, err = parse_req_err(all_slices[i+1])
  434. total_req = total_req + req
  435. total_err = total_err + err
  436. end
  437. end
  438. return total_req, total_err
  439. end
  440. local total_req, total_err = get_clean_req_err(stats_key)
  441. if total_req < 20 then return 0 end
  442. return string.format("%.2f", total_err / total_req)
  443. `
  444. clearChannelModelErrorsLuaScript = `
  445. local prefix = KEYS[1]
  446. local model = KEYS[2]
  447. local channel_id = ARGV[1]
  448. local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
  449. local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"
  450. redis.call("DEL", stats_key)
  451. redis.call("DEL", banned_key)
  452. return redis.status_reply("ok")
  453. `
  454. clearChannelAllModelErrorsLuaScript = `
  455. local prefix = KEYS[1]
  456. local function del_keys(pattern)
  457. local keys = redis.call("KEYS", pattern)
  458. if #keys > 0 then redis.call("DEL", unpack(keys)) end
  459. end
  460. local channel_id = ARGV[1]
  461. local stats_pattern = prefix .. ":model:*:channel:" .. channel_id .. ":stats"
  462. local banned_pattern = prefix .. ":model:*:channel:" .. channel_id .. ":banned"
  463. del_keys(stats_pattern)
  464. del_keys(banned_pattern)
  465. return redis.status_reply("ok")
  466. `
  467. clearAllModelErrorsLuaScript = `
  468. local prefix = KEYS[1]
  469. local function del_keys(pattern)
  470. local keys = redis.call("KEYS", pattern)
  471. if #keys > 0 then redis.call("DEL", unpack(keys)) end
  472. end
  473. del_keys(prefix .. ":model:*:channel:*:stats")
  474. del_keys(prefix .. ":model:*:channel:*:banned")
  475. return redis.status_reply("ok")
  476. `
  477. )