| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580 |
- package monitor
- import (
- "context"
- "fmt"
- "strconv"
- "strings"
- "time"
- "github.com/labring/aiproxy/core/common"
- "github.com/labring/aiproxy/core/common/config"
- "github.com/redis/go-redis/v9"
- )
// Fragments used to assemble and parse the monitor's Redis keys.
const (
	// channelKeyPart separates the model name from the channel ID in a key.
	channelKeyPart = ":channel:"
	// statsKeySuffix terminates a per-channel request/error statistics key.
	statsKeySuffix = ":stats"
	// bannedKeySuffix terminates a per-channel ban marker key.
	bannedKeySuffix = ":banned"
	// modelTotalStatsSuffix terminates a model-wide statistics key.
	modelTotalStatsSuffix = ":total_stats"
)
- func modelKeyPrefix() string {
- return common.RedisKey("model:")
- }
- // Redis scripts
- var (
- addRequestScript = redis.NewScript(addRequestLuaScript)
- getErrorRateScript = redis.NewScript(getErrorRateLuaScript)
- clearChannelModelErrorsScript = redis.NewScript(clearChannelModelErrorsLuaScript)
- clearChannelAllModelErrorsScript = redis.NewScript(clearChannelAllModelErrorsLuaScript)
- clearAllModelErrorsScript = redis.NewScript(clearAllModelErrorsLuaScript)
- )
- func GetModelsErrorRate(ctx context.Context) (map[string]float64, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetModelsErrorRate(ctx)
- }
- result := make(map[string]float64)
- pattern := modelKeyPrefix() + "*" + modelTotalStatsSuffix
- now := time.Now().UnixMilli()
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- model := strings.TrimPrefix(key, modelKeyPrefix())
- model = strings.TrimSuffix(model, modelTotalStatsSuffix)
- rate, err := getErrorRateScript.Run(
- ctx,
- common.RDB,
- []string{key},
- now,
- ).Float64()
- if err != nil {
- return nil, err
- }
- result[model] = rate
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
- // AddRequest adds a request record and checks if channel should be banned
- // warnErrorRate: threshold for warning (default 30%)
- // maxErrorRate: threshold for banning (0 means no banning)
- func AddRequest(
- ctx context.Context,
- model string,
- channelID int64,
- isError, tryBan bool,
- warnErrorRate,
- maxErrorRate float64,
- ) (beyondThreshold, banExecution bool, err error) {
- // Set default warning threshold if not specified
- if warnErrorRate <= 0 {
- warnErrorRate = config.GetDefaultWarnNotifyErrorRate()
- }
- if !common.RedisEnabled {
- beyondThreshold, banExecution = memModelMonitor.AddRequest(
- model,
- channelID,
- isError,
- tryBan,
- warnErrorRate,
- maxErrorRate,
- )
- return beyondThreshold, banExecution, nil
- }
- errorFlag := 0
- if isError {
- errorFlag = 1
- } else {
- tryBan = false
- }
- now := time.Now().UnixMilli()
- val, err := addRequestScript.Run(
- ctx,
- common.RDB,
- []string{common.RedisKeyPrefix(), model},
- channelID,
- errorFlag,
- now,
- warnErrorRate,
- maxErrorRate,
- maxErrorRate > 0,
- tryBan,
- ).Int64()
- if err != nil {
- return false, false, err
- }
- return val == 3, val == 1, nil
- }
- func buildStatsKey(model, channelID string) string {
- return fmt.Sprintf(
- "%s%s%s%v%s",
- modelKeyPrefix(),
- model,
- channelKeyPart,
- channelID,
- statsKeySuffix,
- )
- }
- func getModelChannelID(key string) (string, int64, bool) {
- content := strings.TrimPrefix(key, modelKeyPrefix())
- content = strings.TrimSuffix(content, statsKeySuffix)
- model, channelIDStr, ok := strings.Cut(content, channelKeyPart)
- if !ok {
- return "", 0, false
- }
- channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
- if err != nil {
- return "", 0, false
- }
- return model, channelID, true
- }
- // GetChannelModelErrorRates gets error rates for a specific channel
- func GetChannelModelErrorRates(ctx context.Context, channelID int64) (map[string]float64, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetChannelModelErrorRates(ctx, channelID)
- }
- result := make(map[string]float64)
- pattern := buildStatsKey("*", strconv.FormatInt(channelID, 10))
- now := time.Now().UnixMilli()
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- model, _, ok := getModelChannelID(key)
- if !ok {
- continue
- }
- rate, err := getErrorRateScript.Run(
- ctx,
- common.RDB,
- []string{key},
- now,
- ).Float64()
- if err != nil {
- return nil, err
- }
- result[model] = rate
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
- func GetModelChannelErrorRate(ctx context.Context, model string) (map[int64]float64, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetModelChannelErrorRate(ctx, model)
- }
- result := make(map[int64]float64)
- pattern := buildStatsKey(model, "*")
- now := time.Now().UnixMilli()
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- _, channelID, ok := getModelChannelID(key)
- if !ok {
- continue
- }
- rate, err := getErrorRateScript.Run(
- ctx,
- common.RDB,
- []string{key},
- now,
- ).Float64()
- if err != nil {
- return nil, err
- }
- result[channelID] = rate
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
- // GetBannedChannelsWithModel gets banned channels for a specific model
- func GetBannedChannelsWithModel(ctx context.Context, model string) ([]int64, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetBannedChannelsWithModel(ctx, model)
- }
- result := []int64{}
- prefix := modelKeyPrefix() + model + channelKeyPart
- pattern := prefix + "*" + bannedKeySuffix
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
- channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
- if err != nil {
- continue
- }
- result = append(result, channelID)
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
- // GetBannedChannelsMapWithModel gets banned channels for a specific model as a map for efficient lookups
- func GetBannedChannelsMapWithModel(ctx context.Context, model string) (map[int64]struct{}, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetBannedChannelsMapWithModel(ctx, model)
- }
- result := make(map[int64]struct{})
- prefix := modelKeyPrefix() + model + channelKeyPart
- pattern := prefix + "*" + bannedKeySuffix
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- channelIDStr := strings.TrimSuffix(strings.TrimPrefix(key, prefix), bannedKeySuffix)
- channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
- if err != nil {
- continue
- }
- result[channelID] = struct{}{}
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
- // ClearChannelModelErrors clears errors for a specific channel and model
- func ClearChannelModelErrors(ctx context.Context, model string, channelID int) error {
- if !common.RedisEnabled {
- return memModelMonitor.ClearChannelModelErrors(ctx, model, channelID)
- }
- return clearChannelModelErrorsScript.Run(
- ctx,
- common.RDB,
- []string{common.RedisKeyPrefix(), model},
- strconv.Itoa(channelID),
- ).Err()
- }
- // ClearChannelAllModelErrors clears all errors for a specific channel
- func ClearChannelAllModelErrors(ctx context.Context, channelID int) error {
- if !common.RedisEnabled {
- return memModelMonitor.ClearChannelAllModelErrors(ctx, channelID)
- }
- return clearChannelAllModelErrorsScript.Run(
- ctx,
- common.RDB,
- []string{common.RedisKeyPrefix()},
- strconv.Itoa(channelID),
- ).Err()
- }
- // ClearAllModelErrors clears all error records
- func ClearAllModelErrors(ctx context.Context) error {
- if !common.RedisEnabled {
- return memModelMonitor.ClearAllModelErrors(ctx)
- }
- return clearAllModelErrorsScript.Run(ctx, common.RDB, []string{common.RedisKeyPrefix()}).Err()
- }
- // GetAllBannedModelChannels gets all banned channels for all models
- func GetAllBannedModelChannels(ctx context.Context) (map[string][]int64, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetAllBannedModelChannels(ctx)
- }
- result := make(map[string][]int64)
- pattern := modelKeyPrefix() + "*" + channelKeyPart + "*" + bannedKeySuffix
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- parts := strings.TrimPrefix(key, modelKeyPrefix())
- parts = strings.TrimSuffix(parts, bannedKeySuffix)
- model, channelIDStr, ok := strings.Cut(parts, channelKeyPart)
- if !ok {
- continue
- }
- channelID, err := strconv.ParseInt(channelIDStr, 10, 64)
- if err != nil {
- continue
- }
- if _, exists := result[model]; !exists {
- result[model] = []int64{}
- }
- result[model] = append(result[model], channelID)
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
- // GetAllChannelModelErrorRates gets error rates for all channels and models
- func GetAllChannelModelErrorRates(ctx context.Context) (map[int64]map[string]float64, error) {
- if !common.RedisEnabled {
- return memModelMonitor.GetAllChannelModelErrorRates(ctx)
- }
- result := make(map[int64]map[string]float64)
- pattern := buildStatsKey("*", "*")
- now := time.Now().UnixMilli()
- iter := common.RDB.Scan(ctx, 0, pattern, 0).Iterator()
- for iter.Next(ctx) {
- key := iter.Val()
- model, channelID, ok := getModelChannelID(key)
- if !ok {
- continue
- }
- rate, err := getErrorRateScript.Run(
- ctx,
- common.RDB,
- []string{key},
- now,
- ).Float64()
- if err != nil {
- return nil, err
- }
- if _, exists := result[channelID]; !exists {
- result[channelID] = make(map[string]float64)
- }
- result[channelID][model] = rate
- }
- if err := iter.Err(); err != nil {
- return nil, err
- }
- return result, nil
- }
// Lua scripts run on the Redis server.
//
// Request counters are stored in hashes whose fields are 10-second slice
// numbers (floor(now_ms / 10000)) and whose values are "<requests>:<errors>".
// Whenever a hash is summed, fields older than 12 slices before the current
// one are deleted.
const (
	// addRequestLuaScript records one request and evaluates the ban/warn
	// thresholds of the channel.
	//
	// KEYS: [1] Redis key prefix, [2] model name.
	// ARGV: [1] channel ID, [2] is_error (0/1), [3] now in ms,
	// [4] warn error rate, [5] max error rate, [6] can_ban (0/1),
	// [7] try_ban (0/1).
	//
	// Returns:
	//   0 - nothing to report (fewer than 20 requests in the window, or the
	//       error rate is below the warning threshold)
	//   1 - a ban was set by this call (it expires after 5 minutes)
	//   2 - an immediate ban was requested but the channel is already banned
	//   3 - the error rate is at or above the warning threshold, or at or
	//       above the ban threshold while the channel is already banned
	//
	// Both the per-channel hash and the model-wide hash are updated and get
	// their TTL refreshed to 120 seconds on every call. The model-wide hash is
	// trimmed right after the update: its TTL is refreshed by every request, so
	// without trimming it would accumulate one field per active slice for as
	// long as traffic keeps flowing.
	addRequestLuaScript = `
local prefix = KEYS[1]
local model = KEYS[2]
local channel_id = ARGV[1]
local is_error = tonumber(ARGV[2])
local now_ts = tonumber(ARGV[3])
local warn_error_rate = tonumber(ARGV[4])
local max_error_rate = tonumber(ARGV[5])
local can_ban = tonumber(ARGV[6])
local try_ban = tonumber(ARGV[7])

local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"
local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
local model_stats_key = prefix .. ":model:" .. model .. ":total_stats"

local maxSliceCount = 12
local statsExpiry = maxSliceCount * 10 * 1000
local banExpiry = 5 * 60 * 1000
local current_slice = math.floor(now_ts / 10 / 1000)

local function parse_req_err(value)
    if not value then return 0, 0 end
    local r, e = value:match("^(%d+):(%d+)$")
    return tonumber(r) or 0, tonumber(e) or 0
end

local function update_stats(key)
    local req, err = parse_req_err(redis.call("HGET", key, current_slice))
    req = req + 1
    err = err + (is_error == 1 and 1 or 0)
    redis.call("HSET", key, current_slice, req .. ":" .. err)
    redis.call("PEXPIRE", key, statsExpiry)
    return req, err
end

local function get_clean_req_err(key)
    local total_req, total_err = 0, 0
    local min_valid_slice = current_slice - maxSliceCount
    local all_slices = redis.call("HGETALL", key)
    for i = 1, #all_slices, 2 do
        local slice = tonumber(all_slices[i])
        if slice < min_valid_slice then
            redis.call("HDEL", key, all_slices[i])
        else
            local req, err = parse_req_err(all_slices[i+1])
            total_req = total_req + req
            total_err = total_err + err
        end
    end
    return total_req, total_err
end

update_stats(stats_key)
update_stats(model_stats_key)
-- Drop expired slices of the model-wide hash; its TTL is refreshed on every
-- request, so it would otherwise grow without bound under steady traffic.
get_clean_req_err(model_stats_key)

local function check_channel_error()
    local already_banned = redis.call("EXISTS", banned_key) == 1

    if try_ban == 1 and can_ban == 1 then
        if already_banned then
            return 2
        end
        redis.call("SET", banned_key, 1)
        redis.call("PEXPIRE", banned_key, banExpiry)
        return 1
    end

    local total_req, total_err = get_clean_req_err(stats_key)
    if total_req < 20 then
        return 0
    end

    local error_rate = total_err / total_req

    -- Check if we should ban (only if max_error_rate is set and exceeded)
    if can_ban == 1 and error_rate >= max_error_rate then
        if already_banned then
            return 3 -- Beyond threshold but already banned
        end
        redis.call("SET", banned_key, 1)
        redis.call("PEXPIRE", banned_key, banExpiry)
        return 1 -- Ban executed
    elseif error_rate >= warn_error_rate then
        return 3 -- Beyond warning threshold but not banning
    else
        return 0 -- All good
    end
end

return check_channel_error()
`

	// getErrorRateLuaScript sums the live slices of one stats hash and deletes
	// the expired ones.
	//
	// KEYS: [1] stats key. ARGV: [1] now in ms.
	//
	// Returns 0 when fewer than 20 requests are in the window, otherwise
	// errors/requests formatted with two decimals.
	getErrorRateLuaScript = `
local stats_key = KEYS[1]
local now_ts = tonumber(ARGV[1])

local maxSliceCount = 12
local current_slice = math.floor(now_ts / 10 / 1000)

local function parse_req_err(value)
    if not value then return 0, 0 end
    local r, e = value:match("^(%d+):(%d+)$")
    return tonumber(r) or 0, tonumber(e) or 0
end

local function get_clean_req_err(key)
    local total_req, total_err = 0, 0
    local min_valid_slice = current_slice - maxSliceCount
    local all_slices = redis.call("HGETALL", key)
    for i = 1, #all_slices, 2 do
        local slice = tonumber(all_slices[i])
        if slice < min_valid_slice then
            redis.call("HDEL", key, all_slices[i])
        else
            local req, err = parse_req_err(all_slices[i+1])
            total_req = total_req + req
            total_err = total_err + err
        end
    end
    return total_req, total_err
end

local total_req, total_err = get_clean_req_err(stats_key)
if total_req < 20 then return 0 end
return string.format("%.2f", total_err / total_req)
`

	// clearChannelModelErrorsLuaScript deletes the stats hash and the ban
	// marker of one model/channel pair.
	//
	// KEYS: [1] Redis key prefix, [2] model name. ARGV: [1] channel ID.
	clearChannelModelErrorsLuaScript = `
local prefix = KEYS[1]
local model = KEYS[2]
local channel_id = ARGV[1]
local stats_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":stats"
local banned_key = prefix .. ":model:" .. model .. ":channel:" .. channel_id .. ":banned"

redis.call("DEL", stats_key)
redis.call("DEL", banned_key)

return redis.status_reply("ok")
`

	// clearChannelAllModelErrorsLuaScript deletes the stats hashes and ban
	// markers of one channel for every model.
	//
	// KEYS: [1] Redis key prefix. ARGV: [1] channel ID.
	//
	// Matching keys are deleted in batches because unpack() raises
	// "too many results to unpack" once the key list exceeds the Lua stack
	// limit.
	clearChannelAllModelErrorsLuaScript = `
local prefix = KEYS[1]

local function del_keys(pattern)
    local keys = redis.call("KEYS", pattern)
    local batch = 1000
    for first = 1, #keys, batch do
        redis.call("DEL", unpack(keys, first, math.min(first + batch - 1, #keys)))
    end
end

local channel_id = ARGV[1]
local stats_pattern = prefix .. ":model:*:channel:" .. channel_id .. ":stats"
local banned_pattern = prefix .. ":model:*:channel:" .. channel_id .. ":banned"

del_keys(stats_pattern)
del_keys(banned_pattern)

return redis.status_reply("ok")
`

	// clearAllModelErrorsLuaScript deletes the per-channel stats hashes and ban
	// markers of every model/channel pair. Model-wide total_stats keys are not
	// matched by its patterns.
	//
	// KEYS: [1] Redis key prefix.
	//
	// Matching keys are deleted in batches because unpack() raises
	// "too many results to unpack" once the key list exceeds the Lua stack
	// limit.
	clearAllModelErrorsLuaScript = `
local prefix = KEYS[1]

local function del_keys(pattern)
    local keys = redis.call("KEYS", pattern)
    local batch = 1000
    for first = 1, #keys, batch do
        redis.call("DEL", unpack(keys, first, math.min(first + batch - 1, #keys)))
    end
end

del_keys(prefix .. ":model:*:channel:*:stats")
del_keys(prefix .. ":model:*:channel:*:banned")

return redis.status_reply("ok")
`
)
|