model.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. package monitor
  2. import (
  3. "context"
  4. "fmt"
  5. "strconv"
  6. "strings"
  7. "time"
  8. "github.com/labring/aiproxy/core/common"
  9. "github.com/redis/go-redis/v9"
  10. )
  11. // Redis key prefixes and patterns
  12. const (
  13. modelKeyPrefix = "model:"
  14. bannedKeySuffix = ":banned"
  15. statsKeySuffix = ":stats"
  16. modelTotalStatsSuffix = ":total_stats"
  17. channelKeyPart = ":channel:"
  18. )
  19. // Redis scripts
  20. var (
  21. addRequestScript = redis.NewScript(addRequestLuaScript)
  22. getErrorRateScript = redis.NewScript(getErrorRateLuaScript)
  23. clearChannelModelErrorsScript = redis.NewScript(clearChannelModelErrorsLuaScript)
  24. clearChannelAllModelErrorsScript = redis.NewScript(clearChannelAllModelErrorsLuaScript)
  25. clearAllModelErrorsScript = redis.NewScript(clearAllModelErrorsLuaScript)
  26. )
  27. // GetModelErrorRate gets error rate for a specific model across all channels
  28. func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
  29. if !common.RedisEnabled {
  30. return memModelMonitor.GetModelsErrorRate(ctx)
  31. }
  32. result := make(map[string]float64)
  33. pattern := modelKeyPrefix + "*" + modelTotalStatsSuffix
  34. now := time.Now().UnixMilli()
  35. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  36. for iter.Next(ctx) {
  37. key := iter.Val()
  38. model := strings.TrimPrefix(key, modelKeyPrefix)
  39. model = strings.TrimSuffix(model, modelTotalStatsSuffix)
  40. rate, err := getErrorRateScript.Run(
  41. ctx,
  42. common.RDB,
  43. []string{key},
  44. now,
  45. ).Float64()
  46. if err != nil {
  47. return nil, err
  48. }
  49. result[model] = rate
  50. }
  51. if err := iter.Err(); err != nil {
  52. return nil, err
  53. }
  54. return result, nil
  55. }
  56. // AddRequest adds a request record and checks if channel should be banned
  57. func AddRequest(
  58. ctx context.Context,
  59. model string,
  60. channelID int64,
  61. isError, tryBan bool,
  62. maxErrorRate float64,
  63. ) (beyondThreshold, banExecution bool, err error) {
  64. if !common.RedisEnabled {
  65. beyondThreshold, banExecution = memModelMonitor.AddRequest(
  66. model,
  67. channelID,
  68. isError,
  69. tryBan,
  70. maxErrorRate,
  71. )
  72. return beyondThreshold, banExecution, nil
  73. }
  74. errorFlag := 0
  75. if isError {
  76. errorFlag = 1
  77. } else {
  78. tryBan = false
  79. }
  80. now := time.Now().UnixMilli()
  81. val, err := addRequestScript.Run(
  82. ctx,
  83. common.RDB,
  84. []string{model},
  85. channelID,
  86. errorFlag,
  87. now,
  88. maxErrorRate,
  89. maxErrorRate > 0,
  90. tryBan,
  91. ).Int64()
  92. if err != nil {
  93. return false, false, err
  94. }
  95. return val == 3, val == 1, nil
  96. }
  97. func buildStatsKey(model, channelID string) string {
  98. return fmt.Sprintf(
  99. "%s%s%s%v%s",
  100. modelKeyPrefix,
  101. model,
  102. channelKeyPart,
  103. channelID,
  104. statsKeySuffix,
  105. )
  106. }
  107. func getModelChannelID(key string) (string, int64, bool) {
  108. content := strings.TrimPrefix(key, modelKeyPrefix)
  109. content = strings.TrimSuffix(content, statsKeySuffix)
  110. model, channelIDStr, ok := strings.Cut(content, channelKeyPart)
  111. if !ok {
  112. return "", 0, false
  113. }
  114. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  115. if err != nil {
  116. return "", 0, false
  117. }
  118. return model, channelID, true
  119. }
  120. // GetChannelModelErrorRates gets error rates for a specific channel
  121. func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
  122. if !common.RedisEnabled {
  123. return memModelMonitor.GetChannelModelErrorRates(ctx, channelID)
  124. }
  125. result := make(map[string]float64)
  126. pattern := buildStatsKey("*", strconv.FormatInt(channelID, 10))
  127. now := time.Now().UnixMilli()
  128. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  129. for iter.Next(ctx) {
  130. key := iter.Val()
  131. model, _, ok := getModelChannelID(key)
  132. if !ok {
  133. continue
  134. }
  135. rate, err := getErrorRateScript.Run(
  136. ctx,
  137. common.RDB,
  138. []string{key},
  139. now,
  140. ).Float64()
  141. if err != nil {
  142. return nil, err
  143. }
  144. result[model] = rate
  145. }
  146. if err := iter.Err(); err != nil {
  147. return nil, err
  148. }
  149. return result, nil
  150. }
  151. func GetModelChannelErrorRate(ctx context.Context, model string) (map[int64]float64, error) {
  152. if !common.RedisEnabled {
  153. return memModelMonitor.GetModelChannelErrorRate(ctx, model)
  154. }
  155. result := make(map[int64]float64)
  156. pattern := buildStatsKey(model, "*")
  157. now := time.Now().UnixMilli()
  158. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  159. for iter.Next(ctx) {
  160. key := iter.Val()
  161. _, channelID, ok := getModelChannelID(key)
  162. if !ok {
  163. continue
  164. }
  165. rate, err := getErrorRateScript.Run(
  166. ctx,
  167. common.RDB,
  168. []string{key},
  169. now,
  170. ).Float64()
  171. if err != nil {
  172. return nil, err
  173. }
  174. result[channelID] = rate
  175. }
  176. if err := iter.Err(); err != nil {
  177. return nil, err
  178. }
  179. return result, nil
  180. }
  181. // GetBannedChannelsWithModel gets banned channels for a specific model
  182. func GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, error) {
  183. if !common.RedisEnabled {
  184. return memModelMonitor.GetBannedChannelsWithModel(ctx, model)
  185. }
  186. result := []int64{}
  187. prefix := modelKeyPrefix + model + channelKeyPart
  188. pattern := prefix + "*" + bannedKeySuffix
  189. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  190. for iter.Next(ctx) {
  191. key := iter.Val()
  192. channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
  193. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  194. if err != nil {
  195. continue
  196. }
  197. result = append(result, channelID)
  198. }
  199. if err := iter.Err(); err != nil {
  200. return nil, err
  201. }
  202. return result, nil
  203. }
  204. // ClearChannelModelErrors clears errors for a specific channel and model
  205. func ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
  206. if !common.RedisEnabled {
  207. return memModelMonitor.ClearChannelModelErrors(ctx, model, channelID)
  208. }
  209. return clearChannelModelErrorsScript.Run(
  210. ctx,
  211. common.RDB,
  212. []string{model},
  213. strconv.Itoa(channelID),
  214. ).Err()
  215. }
  216. // ClearChannelAllModelErrors clears all errors for a specific channel
  217. func ClearChannelAllModelErrors(ctx context.Context, channelID int) error {
  218. if !common.RedisEnabled {
  219. return memModelMonitor.ClearChannelAllModelErrors(ctx, channelID)
  220. }
  221. return clearChannelAllModelErrorsScript.Run(
  222. ctx,
  223. common.RDB,
  224. []string{},
  225. strconv.Itoa(channelID),
  226. ).Err()
  227. }
  228. // ClearAllModelErrors clears all error records
  229. func ClearAllModelErrors(ctx context.Context) error {
  230. if !common.RedisEnabled {
  231. return memModelMonitor.ClearAllModelErrors(ctx)
  232. }
  233. return clearAllModelErrorsScript.Run(ctx, common.RDB, []string{}).Err()
  234. }
  235. // GetAllBannedModelChannels gets all banned channels for all models
  236. func GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error) {
  237. if !common.RedisEnabled {
  238. return memModelMonitor.GetAllBannedModelChannels(ctx)
  239. }
  240. result := make(map[string][]int64)
  241. pattern := modelKeyPrefix + "*" + channelKeyPart + "*" + bannedKeySuffix
  242. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  243. for iter.Next(ctx) {
  244. key := iter.Val()
  245. parts := strings.TrimPrefix(key, modelKeyPrefix)
  246. parts = strings.TrimSuffix(parts, bannedKeySuffix)
  247. model, channelIDStr, ok := strings.Cut(parts, channelKeyPart)
  248. if !ok {
  249. continue
  250. }
  251. channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
  252. if err != nil {
  253. continue
  254. }
  255. if _, exists := result[model]; !exists {
  256. result[model] = []int64{}
  257. }
  258. result[model] = append(result[model], channelID)
  259. }
  260. if err := iter.Err(); err != nil {
  261. return nil, err
  262. }
  263. return result, nil
  264. }
  265. // GetAllChannelModelErrorRates gets error rates for all channels and models
  266. func GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) {
  267. if !common.RedisEnabled {
  268. return memModelMonitor.GetAllChannelModelErrorRates(ctx)
  269. }
  270. result := make(map[int64]map[string]float64)
  271. pattern := buildStatsKey("*", "*")
  272. now := time.Now().UnixMilli()
  273. iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
  274. for iter.Next(ctx) {
  275. key := iter.Val()
  276. model, channelID, ok := getModelChannelID(key)
  277. if !ok {
  278. continue
  279. }
  280. rate, err := getErrorRateScript.Run(
  281. ctx,
  282. common.RDB,
  283. []string{key},
  284. now,
  285. ).Float64()
  286. if err != nil {
  287. return nil, err
  288. }
  289. if _, exists := result[channelID]; !exists {
  290. result[channelID] = make(map[string]float64)
  291. }
  292. result[channelID][model] = rate
  293. }
  294. if err := iter.Err(); err != nil {
  295. return nil, err
  296. }
  297. return result, nil
  298. }
  299. // Lua scripts
  300. const (
  301. addRequestLuaScript = `
  302. local model = KEYS[1]
  303. local channel_id = ARGV[1]
  304. local is_error = tonumber(ARGV[2])
  305. local now_ts = tonumber(ARGV[3])
  306. local max_error_rate = tonumber(ARGV[4])
  307. local can_ban = tonumber(ARGV[5])
  308. local try_ban = tonumber(ARGV[6])
  309. local banned_key = "model:" .. model .. ":channel:" .. channel_id .. ":banned"
  310. local stats_key = "model:" .. model .. ":channel:" .. channel_id .. ":stats"
  311. local model_stats_key = "model:" .. model .. ":total_stats"
  312. local maxSliceCount = 12
  313. local statsExpiry = maxSliceCount * 10 * 1000
  314. local banExpiry = 5 * 60 * 1000
  315. local current_slice = math.floor(now_ts / 10 / 1000)
  316. local function parse_req_err(value)
  317. if not value then return 0, 0 end
  318. local r, e = value:match("^(%d+):(%d+)$")
  319. return tonumber(r) or 0, tonumber(e) or 0
  320. end
  321. local function update_stats(key)
  322. local req, err = parse_req_err(redis.call("HGET", key, current_slice))
  323. req = req + 1
  324. err = err + (is_error == 1 and 1 or 0)
  325. redis.call("HSET", key, current_slice, req .. ":" .. err)
  326. redis.call("PEXPIRE", key, statsExpiry)
  327. return req, err
  328. end
  329. local function get_clean_req_err(key)
  330. local total_req, total_err = 0, 0
  331. local min_valid_slice = current_slice - maxSliceCount
  332. local all_slices = redis.call("HGETALL", key)
  333. for i = 1, #all_slices, 2 do
  334. local slice = tonumber(all_slices[i])
  335. if slice < min_valid_slice then
  336. redis.call("HDEL", key, all_slices[i])
  337. else
  338. local req, err = parse_req_err(all_slices[i+1])
  339. total_req = total_req + req
  340. total_err = total_err + err
  341. end
  342. end
  343. return total_req, total_err
  344. end
  345. update_stats(stats_key)
  346. update_stats(model_stats_key)
  347. local function check_channel_error()
  348. local already_banned = redis.call("EXISTS", banned_key) == 1
  349. if try_ban == 1 and can_ban == 1 then
  350. if already_banned then
  351. return 2
  352. end
  353. redis.call("SET", banned_key, 1)
  354. redis.call("PEXPIRE", banned_key, banExpiry)
  355. return 1
  356. end
  357. local total_req, total_err = get_clean_req_err(stats_key)
  358. if total_req < 20 then
  359. return 0
  360. end
  361. if (total_err / total_req) < max_error_rate then
  362. return 0
  363. else
  364. if can_ban == 0 or already_banned then
  365. return 3
  366. end
  367. redis.call("SET", banned_key, 1)
  368. redis.call("PEXPIRE", banned_key, banExpiry)
  369. return 1
  370. end
  371. end
  372. return check_channel_error()
  373. `
  374. getErrorRateLuaScript = `
  375. local stats_key = KEYS[1]
  376. local now_ts = tonumber(ARGV[1])
  377. local maxSliceCount = 12
  378. local current_slice = math.floor(now_ts / 10 / 1000)
  379. local function parse_req_err(value)
  380. if not value then return 0, 0 end
  381. local r, e = value:match("^(%d+):(%d+)$")
  382. return tonumber(r) or 0, tonumber(e) or 0
  383. end
  384. local function get_clean_req_err(key)
  385. local total_req, total_err = 0, 0
  386. local min_valid_slice = current_slice - maxSliceCount
  387. local all_slices = redis.call("HGETALL", key)
  388. for i = 1, #all_slices, 2 do
  389. local slice = tonumber(all_slices[i])
  390. if slice < min_valid_slice then
  391. redis.call("HDEL", key, all_slices[i])
  392. else
  393. local req, err = parse_req_err(all_slices[i+1])
  394. total_req = total_req + req
  395. total_err = total_err + err
  396. end
  397. end
  398. return total_req, total_err
  399. end
  400. local total_req, total_err = get_clean_req_err(stats_key)
  401. if total_req < 20 then return 0 end
  402. return string.format("%.2f", total_err / total_req)
  403. `
  404. clearChannelModelErrorsLuaScript = `
  405. local model = KEYS[1]
  406. local channel_id = ARGV[1]
  407. local stats_key = "model:" .. model .. ":channel:" .. channel_id .. ":stats"
  408. local banned_key = "model:" .. model .. ":channel:" .. channel_id .. ":banned"
  409. redis.call("DEL", stats_key)
  410. redis.call("DEL", banned_key)
  411. return redis.status_reply("ok")
  412. `
  413. clearChannelAllModelErrorsLuaScript = `
  414. local function del_keys(pattern)
  415. local keys = redis.call("KEYS", pattern)
  416. if #keys > 0 then redis.call("DEL", unpack(keys)) end
  417. end
  418. local channel_id = ARGV[1]
  419. local stats_pattern = "model:*:channel:" .. channel_id .. ":stats"
  420. local banned_pattern = "model:*:channel:" .. channel_id .. ":banned"
  421. del_keys(stats_pattern)
  422. del_keys(banned_pattern)
  423. return redis.status_reply("ok")
  424. `
  425. clearAllModelErrorsLuaScript = `
  426. local function del_keys(pattern)
  427. local keys = redis.call("KEYS", pattern)
  428. if #keys > 0 then redis.call("DEL", unpack(keys)) end
  429. end
  430. del_keys("model:*:channel:*:stats")
  431. del_keys("model:*:channel:*:banned")
  432. return redis.status_reply("ok")
  433. `
  434. )