model.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. package monitor
  2. import (
  3. "context"
  4. "fmt"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/labring/aiproxy/core/common"
  9. "github.com/redis/go-redis/v9"
  10. )
// Redis key fragments used to assemble and to parse monitor keys.
//
// The Go helpers in this file compose keys as
// modelKeyPrefix() + <model> + channelKeyPart + <channelID> + <suffix>.
// The Lua scripts in this file spell the same fragments out as string
// literals, so any change here must be mirrored there.
const (
	// bannedKeySuffix terminates the key marking a model/channel pair as banned.
	bannedKeySuffix = ":banned"
	// statsKeySuffix terminates the per-model, per-channel request/error stats key.
	statsKeySuffix = ":stats"
	// modelTotalStatsSuffix terminates the per-model stats key (no channel part).
	modelTotalStatsSuffix = ":total_stats"
	// channelKeyPart sits between the model name and the channel ID in a key.
	channelKeyPart = ":channel:"
)
  18. func modelKeyPrefix() string {
  19. return common.RedisKey("model:")
  20. }
// Redis scripts, each wrapping one of the Lua sources declared in this file
// with redis.NewScript.
var (
	// addRequestScript records one request and decides whether to ban the channel.
	addRequestScript = redis.NewScript(addRequestLuaScript)
	// getErrorRateScript computes the error rate stored under a single stats key.
	getErrorRateScript = redis.NewScript(getErrorRateLuaScript)
	// clearChannelModelErrorsScript deletes the stats and ban keys of one model/channel pair.
	clearChannelModelErrorsScript = redis.NewScript(clearChannelModelErrorsLuaScript)
	// clearChannelAllModelErrorsScript deletes the stats and ban keys of one channel for every model.
	clearChannelAllModelErrorsScript = redis.NewScript(clearChannelAllModelErrorsLuaScript)
	// clearAllModelErrorsScript deletes the per-channel stats and ban keys of every model.
	clearAllModelErrorsScript = redis.NewScript(clearAllModelErrorsLuaScript)
)
  29. func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
  30. if !common.RedisEnabled {
  31. return memModelMonitor.GetModelsErrorRate(ctx)
  32. }
  33. result := make(map[string]float64)
  34. pattern := modelKeyPrefix() + "*" + modelTotalStatsSuffix
  35. now := time.Now().UnixMilli()
  36. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  37. for iter.Next(ctx) {
  38. key := iter.Val()
  39. model := strings.TrimPrefix(key, modelKeyPrefix())
  40. model = strings.TrimSuffix(model, modelTotalStatsSuffix)
  41. rate, err := getErrorRateScript.Run(
  42. ctx,
  43. common.RDB,
  44. []string{key},
  45. now,
  46. ).Float64()
  47. if err != nil {
  48. return nil, err
  49. }
  50. result[model] = rate
  51. }
  52. if err := iter.Err(); err != nil {
  53. return nil, err
  54. }
  55. return result, nil
  56. }
  57. // AddRequest adds a request record and checks if channel should be banned
  58. func AddRequest(
  59. ctx context.Context,
  60. model string,
  61. channelID int64,
  62. isError, tryBan bool,
  63. maxErrorRate float64,
  64. ) (beyondThreshold, banExecution bool, err error) {
  65. if !common.RedisEnabled {
  66. beyondThreshold, banExecution = memModelMonitor.AddRequest(
  67. model,
  68. channelID,
  69. isError,
  70. tryBan,
  71. maxErrorRate,
  72. )
  73. return beyondThreshold, banExecution, nil
  74. }
  75. errorFlag := 0
  76. if isError {
  77. errorFlag = 1
  78. } else {
  79. tryBan = false
  80. }
  81. now := time.Now().UnixMilli()
  82. val, err := addRequestScript.Run(
  83. ctx,
  84. common.RDB,
  85. []string{common.RedisKeyPrefix(), model},
  86. channelID,
  87. errorFlag,
  88. now,
  89. maxErrorRate,
  90. maxErrorRate > 0,
  91. tryBan,
  92. ).Int64()
  93. if err != nil {
  94. return false, false, err
  95. }
  96. return val == 3, val == 1, nil
  97. }
  98. func buildStatsKey(model, channelID string) string {
  99. return fmt.Sprintf(
  100. "%s%s%s%v%s",
  101. modelKeyPrefix(),
  102. model,
  103. channelKeyPart,
  104. channelID,
  105. statsKeySuffix,
  106. )
  107. }
  108. func getModelChannelID(key string) (string, int64, bool) {
  109. content := strings.TrimPrefix(key, modelKeyPrefix())
  110. content = strings.TrimSuffix(content, statsKeySuffix)
  111. model, channelIDStr, ok := strings.Cut(content, channelKeyPart)
  112. if !ok {
  113. return "", 0, false
  114. }
  115. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  116. if err != nil {
  117. return "", 0, false
  118. }
  119. return model, channelID, true
  120. }
  121. // GetChannelModelErrorRates gets error rates for a specific channel
  122. func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
  123. if !common.RedisEnabled {
  124. return memModelMonitor.GetChannelModelErrorRates(ctx, channelID)
  125. }
  126. result := make(map[string]float64)
  127. pattern := buildStatsKey("*", strconv.FormatInt(channelID, 10))
  128. now := time.Now().UnixMilli()
  129. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  130. for iter.Next(ctx) {
  131. key := iter.Val()
  132. model, _, ok := getModelChannelID(key)
  133. if !ok {
  134. continue
  135. }
  136. rate, err := getErrorRateScript.Run(
  137. ctx,
  138. common.RDB,
  139. []string{key},
  140. now,
  141. ).Float64()
  142. if err != nil {
  143. return nil, err
  144. }
  145. result[model] = rate
  146. }
  147. if err := iter.Err(); err != nil {
  148. return nil, err
  149. }
  150. return result, nil
  151. }
  152. func GetModelChannelErrorRate(ctx context.Context, model string) (map[int64]float64, error) {
  153. if !common.RedisEnabled {
  154. return memModelMonitor.GetModelChannelErrorRate(ctx, model)
  155. }
  156. result := make(map[int64]float64)
  157. pattern := buildStatsKey(model, "*")
  158. now := time.Now().UnixMilli()
  159. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  160. for iter.Next(ctx) {
  161. key := iter.Val()
  162. _, channelID, ok := getModelChannelID(key)
  163. if !ok {
  164. continue
  165. }
  166. rate, err := getErrorRateScript.Run(
  167. ctx,
  168. common.RDB,
  169. []string{key},
  170. now,
  171. ).Float64()
  172. if err != nil {
  173. return nil, err
  174. }
  175. result[channelID] = rate
  176. }
  177. if err := iter.Err(); err != nil {
  178. return nil, err
  179. }
  180. return result, nil
  181. }
  182. // GetBannedChannelsWithModel gets banned channels for a specific model
  183. func GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, error) {
  184. if !common.RedisEnabled {
  185. return memModelMonitor.GetBannedChannelsWithModel(ctx, model)
  186. }
  187. result := []int64{}
  188. prefix := modelKeyPrefix() + model + channelKeyPart
  189. pattern := prefix + "*" + bannedKeySuffix
  190. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  191. for iter.Next(ctx) {
  192. key := iter.Val()
  193. channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
  194. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  195. if err != nil {
  196. continue
  197. }
  198. result = append(result, channelID)
  199. }
  200. if err := iter.Err(); err != nil {
  201. return nil, err
  202. }
  203. return result, nil
  204. }
  205. // ClearChannelModelErrors clears errors for a specific channel and model
  206. func ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
  207. if !common.RedisEnabled {
  208. return memModelMonitor.ClearChannelModelErrors(ctx, model, channelID)
  209. }
  210. return clearChannelModelErrorsScript.Run(
  211. ctx,
  212. common.RDB,
  213. []string{common.RedisKeyPrefix(), model},
  214. strconv.Itoa(channelID),
  215. ).Err()
  216. }
  217. // ClearChannelAllModelErrors clears all errors for a specific channel
  218. func ClearChannelAllModelErrors(ctx context.Context, channelID int) error {
  219. if !common.RedisEnabled {
  220. return memModelMonitor.ClearChannelAllModelErrors(ctx, channelID)
  221. }
  222. return clearChannelAllModelErrorsScript.Run(
  223. ctx,
  224. common.RDB,
  225. []string{common.RedisKeyPrefix()},
  226. strconv.Itoa(channelID),
  227. ).Err()
  228. }
  229. // ClearAllModelErrors clears all error records
  230. func ClearAllModelErrors(ctx context.Context) error {
  231. if !common.RedisEnabled {
  232. return memModelMonitor.ClearAllModelErrors(ctx)
  233. }
  234. return clearAllModelErrorsScript.Run(ctx, common.RDB, []string{common.RedisKeyPrefix()}).Err()
  235. }
  236. // GetAllBannedModelChannels gets all banned channels for all models
  237. func GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error) {
  238. if !common.RedisEnabled {
  239. return memModelMonitor.GetAllBannedModelChannels(ctx)
  240. }
  241. result := make(map[string][]int64)
  242. pattern := modelKeyPrefix() + "*" + channelKeyPart + "*" + bannedKeySuffix
  243. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  244. for iter.Next(ctx) {
  245. key := iter.Val()
  246. parts := strings.TrimPrefix(key, modelKeyPrefix())
  247. parts = strings.TrimSuffix(parts, bannedKeySuffix)
  248. model, channelIDStr, ok := strings.Cut(parts, channelKeyPart)
  249. if !ok {
  250. continue
  251. }
  252. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  253. if err != nil {
  254. continue
  255. }
  256. if _, exists := result[model]; !exists {
  257. result[model] = []int64{}
  258. }
  259. result[model] = append(result[model], channelID)
  260. }
  261. if err := iter.Err(); err != nil {
  262. return nil, err
  263. }
  264. return result, nil
  265. }
  266. // GetAllChannelModelErrorRates gets error rates for all channels and models
  267. func GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) {
  268. if !common.RedisEnabled {
  269. return memModelMonitor.GetAllChannelModelErrorRates(ctx)
  270. }
  271. result := make(map[int64]map[string]float64)
  272. pattern := buildStatsKey("*", "*")
  273. now := time.Now().UnixMilli()
  274. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  275. for iter.Next(ctx) {
  276. key := iter.Val()
  277. model, channelID, ok := getModelChannelID(key)
  278. if !ok {
  279. continue
  280. }
  281. rate, err := getErrorRateScript.Run(
  282. ctx,
  283. common.RDB,
  284. []string{key},
  285. now,
  286. ).Float64()
  287. if err != nil {
  288. return nil, err
  289. }
  290. if _, exists := result[channelID]; !exists {
  291. result[channelID] = make(map[string]float64)
  292. }
  293. result[channelID][model] = rate
  294. }
  295. if err := iter.Err(); err != nil {
  296. return nil, err
  297. }
  298. return result, nil
  299. }
  300. // Lua scripts
  301. const (
	// addRequestLuaScript records one request and decides whether the channel
	// should be banned for the model.
	//
	// Inputs:
	//   KEYS[1] key prefix, KEYS[2] model name
	//   ARGV[1] channel ID, ARGV[2] error flag (1 = failed request),
	//   ARGV[3] current time in milliseconds, ARGV[4] maximum error rate,
	//   ARGV[5] can-ban flag, ARGV[6] try-ban flag
	//
	// Stats are kept in a hash with one field per 10-second slice, each value
	// being "<requests>:<errors>". Both the per-channel stats key and the
	// per-model total-stats key are updated, and their expiry is refreshed to
	// 12 slices (120 s). Slices older than current_slice - 12 are deleted while
	// the per-channel totals are summed.
	//
	// Reply:
	//   0  nothing to do (fewer than 20 requests in the window, or rate below the maximum)
	//   1  the ban key was set by this call (expires after 5 minutes)
	//   2  a forced ban was requested but the channel was already banned
	//   3  the rate reached the maximum, but banning is disabled or the channel is already banned
	addRequestLuaScript = `
local prefix = KEYS[1]
local model = KEYS[2]
local channel_id = ARGV[1]
local is_error = tonumber(ARGV[2])
local now_ts = tonumber(ARGV[3])
local max_error_rate = tonumber(ARGV[4])
local can_ban = tonumber(ARGV[5])
local try_ban = tonumber(ARGV[6])
local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"
local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
local model_stats_key = prefix .. ":model:" .. model .. ":total_stats"
local maxSliceCount = 12
local statsExpiry = maxSliceCount * 10 * 1000
local banExpiry = 5 * 60 * 1000
local current_slice = math.floor(now_ts / 10 / 1000)
local function parse_req_err(value)
if not value then return 0, 0 end
local r, e = value:match("^(%d+):(%d+)$")
return tonumber(r) or 0, tonumber(e) or 0
end
local function update_stats(key)
local req, err = parse_req_err(redis.call("HGET", key, current_slice))
req = req + 1
err = err + (is_error == 1 and 1 or 0)
redis.call("HSET", key, current_slice, req .. ":" .. err)
redis.call("PEXPIRE", key, statsExpiry)
return req, err
end
local function get_clean_req_err(key)
local total_req, total_err = 0, 0
local min_valid_slice = current_slice - maxSliceCount
local all_slices = redis.call("HGETALL", key)
for i = 1, #all_slices, 2 do
local slice = tonumber(all_slices[i])
if slice < min_valid_slice then
redis.call("HDEL", key, all_slices[i])
else
local req, err = parse_req_err(all_slices[i+1])
total_req = total_req + req
total_err = total_err + err
end
end
return total_req, total_err
end
update_stats(stats_key)
update_stats(model_stats_key)
local function check_channel_error()
local already_banned = redis.call("EXISTS", banned_key) == 1
if try_ban == 1 and can_ban == 1 then
if already_banned then
return 2
end
redis.call("SET", banned_key, 1)
redis.call("PEXPIRE", banned_key, banExpiry)
return 1
end
local total_req, total_err = get_clean_req_err(stats_key)
if total_req < 20 then
return 0
end
if (total_err / total_req) < max_error_rate then
return 0
else
if can_ban == 0 or already_banned then
return 3
end
redis.call("SET", banned_key, 1)
redis.call("PEXPIRE", banned_key, banExpiry)
return 1
end
end
return check_channel_error()
`
	// getErrorRateLuaScript computes the error rate stored under one stats key.
	//
	// Inputs: KEYS[1] stats key (a hash of 10-second slices holding
	// "<requests>:<errors>"), ARGV[1] current time in milliseconds.
	//
	// Slices older than current_slice - 12 are deleted from the hash as a side
	// effect; the remaining slices are summed.
	//
	// Reply: the integer 0 when fewer than 20 requests were counted, otherwise
	// errors/requests formatted with two decimals as a string.
	getErrorRateLuaScript = `
local stats_key = KEYS[1]
local now_ts = tonumber(ARGV[1])
local maxSliceCount = 12
local current_slice = math.floor(now_ts / 10 / 1000)
local function parse_req_err(value)
if not value then return 0, 0 end
local r, e = value:match("^(%d+):(%d+)$")
return tonumber(r) or 0, tonumber(e) or 0
end
local function get_clean_req_err(key)
local total_req, total_err = 0, 0
local min_valid_slice = current_slice - maxSliceCount
local all_slices = redis.call("HGETALL", key)
for i = 1, #all_slices, 2 do
local slice = tonumber(all_slices[i])
if slice < min_valid_slice then
redis.call("HDEL", key, all_slices[i])
else
local req, err = parse_req_err(all_slices[i+1])
total_req = total_req + req
total_err = total_err + err
end
end
return total_req, total_err
end
local total_req, total_err = get_clean_req_err(stats_key)
if total_req < 20 then return 0 end
return string.format("%.2f", total_err / total_req)
`
	// clearChannelModelErrorsLuaScript deletes the stats key and the ban key of
	// a single model/channel pair.
	//
	// Inputs: KEYS[1] key prefix, KEYS[2] model name, ARGV[1] channel ID.
	// Reply: status "ok".
	clearChannelModelErrorsLuaScript = `
local prefix = KEYS[1]
local model = KEYS[2]
local channel_id = ARGV[1]
local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"
redis.call("DEL", stats_key)
redis.call("DEL", banned_key)
return redis.status_reply("ok")
`
  416. clearChannelAllModelErrorsLuaScript = `
  417. local prefix = KEYS[1]
  418. local function del_keys(pattern)
  419. local keys = redis.call("KEYS", pattern)
  420. if #keys > 0 then redis.call("DEL", unpack(keys)) end
  421. end
  422. local channel_id = ARGV[1]
  423. local stats_pattern = prefix .. ":model:*:channel:" .. channel_id .. ":stats"
  424. local banned_pattern = prefix .. ":model:*:channel:" .. channel_id .. ":banned"
  425. del_keys(stats_pattern)
  426. del_keys(banned_pattern)
  427. return redis.status_reply("ok")
  428. `
  429. clearAllModelErrorsLuaScript = `
  430. local prefix = KEYS[1]
  431. local function del_keys(pattern)
  432. local keys = redis.call("KEYS", pattern)
  433. if #keys > 0 then redis.call("DEL", unpack(keys)) end
  434. end
  435. del_keys(prefix .. ":model:*:channel:*:stats")
  436. del_keys(prefix .. ":model:*:channel:*:banned")
  437. return redis.status_reply("ok")
  438. `
  439. )