Browse Source

feat: check group usage and alert (#390)

* feat: check group usage and alert

* fix: ci lint

* fix: use three day usage check ratio
zijiren 6 months ago
parent
commit
8a2bc193fb
6 changed files with 655 additions and 297 deletions
  1. 32 0
      core/common/config/config.go
  2. 10 297
      core/main.go
  3. 35 0
      core/model/option.go
  4. 159 0
      core/model/usage_alert.go
  5. 165 0
      core/startup.go
  6. 254 0
      core/task/task.go

+ 32 - 0
core/common/config/config.go

@@ -26,6 +26,9 @@ var (
 	defaultChannelModelMapping   atomic.Value
 	groupMaxTokenNum             atomic.Int64
 	groupConsumeLevelRatio       atomic.Value
+	usageAlertThreshold          atomic.Int64 // default 0 means disabled
+	usageAlertWhitelist          atomic.Value
+	usageAlertMinAvgThreshold    atomic.Int64 // 前三天平均用量最低阈值,default 0 means no limit
 
 	defaultWarnNotifyErrorRate uint64 = math.Float64bits(0.5)
 
@@ -38,6 +41,7 @@ func init() {
 	defaultChannelModels.Store(make(map[int][]string))
 	defaultChannelModelMapping.Store(make(map[int]map[string]string))
 	groupConsumeLevelRatio.Store(make(map[float64]float64))
+	usageAlertWhitelist.Store(make([]string, 0))
 	notifyNote.Store("")
 	defaultMCPHost.Store("")
 	publicMCPHost.Store("")
@@ -247,3 +251,31 @@ func SetDefaultWarnNotifyErrorRate(rate float64) {
 	rate = env.Float64("DEFAULT_WARN_NOTIFY_ERROR_RATE", rate)
 	atomic.StoreUint64(&defaultWarnNotifyErrorRate, math.Float64bits(rate))
 }
+
+// GetUsageAlertThreshold returns the configured usage alert threshold.
+// The zero default means the usage alert is disabled.
+func GetUsageAlertThreshold() int64 {
+	current := usageAlertThreshold.Load()
+	return current
+}
+
+// SetUsageAlertThreshold stores the usage alert threshold. The given value is
+// first passed through env.Int64 under the USAGE_ALERT_THRESHOLD key, and the
+// result of that call is what gets stored.
+// NOTE(review): presumably the environment variable wins when it is set;
+// confirm against env.Int64.
+func SetUsageAlertThreshold(threshold int64) {
+	usageAlertThreshold.Store(env.Int64("USAGE_ALERT_THRESHOLD", threshold))
+}
+
+// GetUsageAlertWhitelist returns the stored usage alert whitelist, or nil when
+// the stored value is not a []string. The returned slice is the stored one,
+// not a copy.
+func GetUsageAlertWhitelist() []string {
+	if list, ok := usageAlertWhitelist.Load().([]string); ok {
+		return list
+	}
+
+	return nil
+}
+
+// SetUsageAlertWhitelist stores the usage alert whitelist. The given slice is
+// first passed through env.JSON under the USAGE_ALERT_WHITELIST key, and the
+// result of that call is what gets stored.
+// NOTE(review): presumably the environment variable (parsed as JSON) wins when
+// it is set; confirm against env.JSON.
+func SetUsageAlertWhitelist(whitelist []string) {
+	resolved := env.JSON("USAGE_ALERT_WHITELIST", whitelist)
+	usageAlertWhitelist.Store(resolved)
+}
+
+// GetUsageAlertMinAvgThreshold returns the configured minimum for the
+// previous-three-day average usage. The zero default means no limit.
+func GetUsageAlertMinAvgThreshold() int64 {
+	current := usageAlertMinAvgThreshold.Load()
+	return current
+}
+
+// SetUsageAlertMinAvgThreshold stores the minimum previous-three-day average
+// usage. The given value is first passed through env.Int64 under the
+// USAGE_ALERT_MIN_AVG_THRESHOLD key, and the result of that call is what gets
+// stored.
+// NOTE(review): presumably the environment variable wins when it is set;
+// confirm against env.Int64.
+func SetUsageAlertMinAvgThreshold(threshold int64) {
+	usageAlertMinAvgThreshold.Store(env.Int64("USAGE_ALERT_MIN_AVG_THRESHOLD", threshold))
+}

+ 10 - 297
core/main.go

@@ -2,36 +2,19 @@ package main
 
 import (
 	"context"
-	"crypto/sha256"
-	"encoding/hex"
-	"errors"
 	"flag"
-	"fmt"
-	"net/http"
 	"os"
 	"os/signal"
-	"path/filepath"
-	"slices"
 	"sync"
 	"syscall"
 	"time"
 
-	"github.com/bytedance/sonic"
-	"github.com/gin-gonic/gin"
-	"github.com/joho/godotenv"
 	"github.com/labring/aiproxy/core/common"
-	"github.com/labring/aiproxy/core/common/balance"
 	"github.com/labring/aiproxy/core/common/config"
 	"github.com/labring/aiproxy/core/common/consume"
-	"github.com/labring/aiproxy/core/common/conv"
-	"github.com/labring/aiproxy/core/common/ipblack"
-	"github.com/labring/aiproxy/core/common/notify"
-	"github.com/labring/aiproxy/core/common/pprof"
-	"github.com/labring/aiproxy/core/common/trylock"
 	"github.com/labring/aiproxy/core/controller"
-	"github.com/labring/aiproxy/core/middleware"
 	"github.com/labring/aiproxy/core/model"
-	"github.com/labring/aiproxy/core/router"
+	"github.com/labring/aiproxy/core/task"
 	log "github.com/sirupsen/logrus"
 )
 
@@ -45,280 +28,6 @@ func init() {
 	flag.IntVar(&pprofPort, "pprof-port", 15000, "pport http server port")
 }
 
-func initializeServices() error {
-	initializePprof()
-	initializeNotifier()
-
-	if err := common.InitRedisClient(); err != nil {
-		return err
-	}
-
-	if err := initializeBalance(); err != nil {
-		return err
-	}
-
-	if err := model.InitDB(); err != nil {
-		return err
-	}
-
-	if err := initializeOptionAndCaches(); err != nil {
-		return err
-	}
-
-	return model.InitLogDB(int(config.GetCleanLogBatchSize()))
-}
-
-func initializePprof() {
-	go func() {
-		err := pprof.RunPprofServer(pprofPort)
-		if err != nil {
-			log.Errorf("run pprof server error: %v", err)
-		}
-	}()
-}
-
-func initializeBalance() error {
-	sealosJwtKey := os.Getenv("SEALOS_JWT_KEY")
-	if sealosJwtKey == "" {
-		log.Info("SEALOS_JWT_KEY is not set, balance will not be enabled")
-		return nil
-	}
-
-	log.Info("SEALOS_JWT_KEY is set, balance will be enabled")
-
-	return balance.InitSealos(sealosJwtKey, os.Getenv("SEALOS_ACCOUNT_URL"))
-}
-
-func initializeNotifier() {
-	feishuWh := os.Getenv("NOTIFY_FEISHU_WEBHOOK")
-	if feishuWh != "" {
-		notify.SetDefaultNotifier(notify.NewFeishuNotify(feishuWh))
-		log.Info("NOTIFY_FEISHU_WEBHOOK is set, notifier will be use feishu")
-	}
-}
-
-func initializeOptionAndCaches() error {
-	log.Info("starting init config and channel")
-
-	if err := model.InitOption2DB(); err != nil {
-		return err
-	}
-
-	return model.InitModelConfigAndChannelCache()
-}
-
-func startSyncServices(ctx context.Context, wg *sync.WaitGroup) {
-	wg.Add(2)
-
-	go model.SyncOptions(ctx, wg, time.Second*5)
-	go model.SyncModelConfigAndChannelCache(ctx, wg, time.Second*10)
-}
-
-func setupHTTPServer() (*http.Server, *gin.Engine) {
-	server := gin.New()
-
-	server.Use(
-		middleware.GinRecoveryHandler,
-		middleware.NewLog(log.StandardLogger()),
-		middleware.RequestIDMiddleware,
-		middleware.CORS(),
-	)
-	router.SetRouter(server)
-
-	listenEnv := os.Getenv("LISTEN")
-	if listenEnv != "" {
-		listen = listenEnv
-	}
-
-	return &http.Server{
-		Addr:              listen,
-		ReadHeaderTimeout: 10 * time.Second,
-		Handler:           server,
-	}, server
-}
-
-func autoTestBannedModels(ctx context.Context) {
-	ticker := time.NewTicker(time.Second * 30)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			controller.AutoTestBannedModels()
-		}
-	}
-}
-
-func detectIPGroupsTask(ctx context.Context) {
-	ticker := time.NewTicker(time.Minute)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			if !trylock.Lock("detectIPGroups", time.Minute) {
-				continue
-			}
-
-			detectIPGroups()
-		}
-	}
-}
-
-func detectIPGroups() {
-	threshold := config.GetIPGroupsThreshold()
-	if threshold < 1 {
-		return
-	}
-
-	ipGroupList, err := model.GetIPGroups(int(threshold), time.Now().Add(-time.Hour), time.Now())
-	if err != nil {
-		notify.ErrorThrottle("detectIPGroups", time.Minute, "detect IP groups failed", err.Error())
-	}
-
-	if len(ipGroupList) == 0 {
-		return
-	}
-
-	banThreshold := config.GetIPGroupsBanThreshold()
-	for ip, groups := range ipGroupList {
-		slices.Sort(groups)
-
-		groupsJSON, err := sonic.MarshalString(groups)
-		if err != nil {
-			notify.ErrorThrottle(
-				"detectIPGroupsMarshal",
-				time.Minute,
-				"marshal IP groups failed",
-				err.Error(),
-			)
-
-			continue
-		}
-
-		if banThreshold >= threshold && len(groups) >= int(banThreshold) {
-			rowsAffected, err := model.UpdateGroupsStatus(groups, model.GroupStatusDisabled)
-			if err != nil {
-				notify.ErrorThrottle(
-					"detectIPGroupsBan",
-					time.Minute,
-					"update groups status failed",
-					err.Error(),
-				)
-			}
-
-			if rowsAffected > 0 {
-				notify.Warn(
-					fmt.Sprintf(
-						"Suspicious activity: IP %s is using %d groups (exceeds ban threshold of %d). IP and all groups have been disabled.",
-						ip,
-						len(groups),
-						banThreshold,
-					),
-					groupsJSON,
-				)
-				ipblack.SetIPBlackAnyWay(ip, time.Hour*48)
-			}
-
-			continue
-		}
-
-		h := sha256.New()
-		h.Write(conv.StringToBytes(groupsJSON))
-		groupsHash := hex.EncodeToString(h.Sum(nil))
-		hashKey := fmt.Sprintf("%s:%s", ip, groupsHash)
-
-		notify.WarnThrottle(
-			hashKey,
-			time.Hour*3,
-			fmt.Sprintf(
-				"Potential abuse: IP %s is using %d groups (exceeds threshold of %d)",
-				ip,
-				len(groups),
-				threshold,
-			),
-			groupsJSON,
-		)
-	}
-}
-
-func cleanLog(ctx context.Context) {
-	// the interval should not be too large to avoid cleaning too much at once
-	ticker := time.NewTicker(time.Second * 5)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			if !trylock.Lock("cleanLog", time.Second*5) {
-				continue
-			}
-
-			optimize := trylock.Lock("optimizeLog", time.Hour*24)
-
-			err := model.CleanLog(int(config.GetCleanLogBatchSize()), optimize)
-			if err != nil {
-				notify.ErrorThrottle("cleanLog", time.Minute*5, "clean log failed", err.Error())
-			}
-		}
-	}
-}
-
-var loadedEnvFiles []string
-
-func loadEnv() {
-	envfiles := []string{
-		".env",
-		".env.local",
-	}
-	for _, envfile := range envfiles {
-		absPath, err := filepath.Abs(envfile)
-		if err != nil {
-			panic(
-				fmt.Sprintf(
-					"failed to get absolute path of env file: %s, error: %s",
-					envfile,
-					err.Error(),
-				),
-			)
-		}
-
-		file, err := os.Stat(absPath)
-		if err != nil {
-			continue
-		}
-
-		if file.IsDir() {
-			continue
-		}
-
-		if err := godotenv.Overload(absPath); err != nil {
-			panic(fmt.Sprintf("failed to load env file: %s, error: %s", absPath, err.Error()))
-		}
-
-		loadedEnvFiles = append(loadedEnvFiles, absPath)
-	}
-}
-
-func printLoadedEnvFiles() {
-	for _, envfile := range loadedEnvFiles {
-		log.Infof("loaded env file: %s", envfile)
-	}
-}
-
-func listenAndServe(srv *http.Server) {
-	if err := srv.ListenAndServe(); err != nil &&
-		!errors.Is(err, http.ErrServerClosed) {
-		log.Fatal("failed to start HTTP server: " + err.Error())
-	}
-}
-
 // Swagger godoc
 //
 //	@title						AI Proxy Swagger API
@@ -337,7 +46,7 @@ func main() {
 
 	printLoadedEnvFiles()
 
-	if err := initializeServices(); err != nil {
+	if err := initializeServices(pprofPort); err != nil {
 		log.Fatal("failed to initialize services: " + err.Error())
 	}
 
@@ -353,19 +62,23 @@ func main() {
 	var wg sync.WaitGroup
 	startSyncServices(ctx, &wg)
 
-	srv, _ := setupHTTPServer()
+	srv, _ := setupHTTPServer(listen)
 
 	log.Info("auto test banned models task started")
 
-	go autoTestBannedModels(ctx)
+	go task.AutoTestBannedModelsTask(ctx)
 
 	log.Info("clean log task started")
 
-	go cleanLog(ctx)
+	go task.CleanLogTask(ctx)
 
 	log.Info("detect ip groups task started")
 
-	go detectIPGroupsTask(ctx)
+	go task.DetectIPGroupsTask(ctx)
+
+	log.Info("usage alert task started")
+
+	go task.UsageAlertTask(ctx)
 
 	log.Info("update channels balance task started")
 

+ 35 - 0
core/model/option.go

@@ -111,6 +111,18 @@ func initOptionMap() error {
 		-1,
 		64,
 	)
+	optionMap["UsageAlertThreshold"] = strconv.FormatInt(config.GetUsageAlertThreshold(), 10)
+
+	usageAlertWhitelistJSON, err := sonic.Marshal(config.GetUsageAlertWhitelist())
+	if err != nil {
+		return err
+	}
+
+	optionMap["UsageAlertWhitelist"] = conv.BytesToString(usageAlertWhitelistJSON)
+	optionMap["UsageAlertMinAvgThreshold"] = strconv.FormatInt(
+		config.GetUsageAlertMinAvgThreshold(),
+		10,
+	)
 
 	optionKeys = make([]string, 0, len(optionMap))
 	for key := range optionMap {
@@ -410,6 +422,29 @@ func updateOption(key, value string, isInit bool) (err error) {
 		}
 
 		config.SetDefaultWarnNotifyErrorRate(rate)
+	case "UsageAlertThreshold":
+		threshold, err := strconv.ParseInt(value, 10, 64)
+		if err != nil {
+			return err
+		}
+
+		config.SetUsageAlertThreshold(threshold)
+	case "UsageAlertWhitelist":
+		var whitelist []string
+
+		err := sonic.Unmarshal(conv.StringToBytes(value), &whitelist)
+		if err != nil {
+			return err
+		}
+
+		config.SetUsageAlertWhitelist(whitelist)
+	case "UsageAlertMinAvgThreshold":
+		threshold, err := strconv.ParseInt(value, 10, 64)
+		if err != nil {
+			return err
+		}
+
+		config.SetUsageAlertMinAvgThreshold(threshold)
 	default:
 		return ErrUnknownOptionKey
 	}

+ 159 - 0
core/model/usage_alert.go

@@ -0,0 +1,159 @@
+package model
+
+import (
+	"time"
+)
+
+// GroupUsageAlertItem is a single usage alert entry for one group.
+type GroupUsageAlertItem struct {
+	GroupID           string
+	ThreeDayAvgAmount float64 // average usage over the previous three days
+	TodayAmount       float64 // usage summed since the start of today
+	Ratio             float64 // TodayAmount divided by ThreeDayAvgAmount
+}
+
+// calculateSpikeThreshold maps a previous-three-day average usage to the
+// multiplier today's usage has to reach before an alert is raised.
+// Larger averages get a lower multiplier to avoid false alarms; smaller
+// averages get a higher one so that only real anomalies are caught.
+func calculateSpikeThreshold(avgAmount float64) float64 {
+	// Tiers are ordered by ascending exclusive upper bound; the first tier the
+	// average falls under decides the multiplier.
+	tiers := [...]struct {
+		upper float64
+		ratio float64
+	}{
+		{upper: 100, ratio: 5.0},
+		{upper: 300, ratio: 4.0},
+		{upper: 1000, ratio: 3.0},
+		{upper: 2000, ratio: 2.5},
+		{upper: 5000, ratio: 2.0},
+	}
+
+	for _, tier := range tiers {
+		if avgAmount < tier.upper {
+			return tier.ratio
+		}
+	}
+
+	// Averages of 5000 and above fall through to the lowest multiplier.
+	return 1.5
+}
+
+// GetGroupUsageAlert returns the groups whose usage today spiked relative to
+// their average usage over the previous three days.
+//
+// A group is reported when all of the following hold:
+//  1. Its usage summed from local midnight until now is >= threshold.
+//  2. It is not listed in whitelist.
+//  3. Its previous-three-day average (the three-day total divided by 3) is
+//     non-zero and >= minAvgThreshold. A group without usage rows in that
+//     window has an average of 0 and is therefore never reported.
+//  4. Today's usage divided by that average is >= the multiplier returned by
+//     calculateSpikeThreshold for the average:
+//     - average < 100: 5.0
+//     - average in [100, 300): 4.0
+//     - average in [300, 1000): 3.0
+//     - average in [1000, 2000): 2.5
+//     - average in [2000, 5000): 2.0
+//     - average >= 5000: 1.5
+//
+// The result is nil when no group qualifies. Alerts are returned in the order
+// the first query produced the groups. A query error is returned unchanged.
+func GetGroupUsageAlert(
+	threshold, minAvgThreshold float64,
+	whitelist []string,
+) ([]GroupUsageAlertItem, error) {
+	now := time.Now()
+
+	// Today spans local midnight up to the current instant.
+	todayStart := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
+
+	// The baseline spans the three calendar days before today. AddDate moves by
+	// calendar days, so the window still starts at local midnight when a
+	// daylight saving transition falls inside it; a fixed 72h offset would not.
+	threeDaysAgoStart := todayStart.AddDate(0, 0, -3)
+	yesterdayEnd := todayStart.Add(-time.Second) // 23:59:59 yesterday
+
+	// groupUsage receives one aggregated row per group from either query.
+	type groupUsage struct {
+		GroupID    string
+		UsedAmount float64
+	}
+
+	// First pass: groups whose usage today already reaches the threshold.
+	var todayUsages []groupUsage
+
+	err := LogDB.
+		Model(&GroupSummary{}).
+		Select("group_id, SUM(used_amount) as used_amount").
+		Where("hour_timestamp BETWEEN ? AND ?", todayStart.Unix(), now.Unix()).
+		Group("group_id").
+		Having("SUM(used_amount) >= ?", threshold).
+		Find(&todayUsages).Error
+	if err != nil {
+		return nil, err
+	}
+
+	// Drop whitelisted groups up front so they are not queried again below.
+	whitelisted := make(map[string]struct{}, len(whitelist))
+	for _, groupID := range whitelist {
+		whitelisted[groupID] = struct{}{}
+	}
+
+	candidates := make([]groupUsage, 0, len(todayUsages))
+	groupIDs := make([]string, 0, len(todayUsages))
+
+	for _, usage := range todayUsages {
+		if _, skip := whitelisted[usage.GroupID]; skip {
+			continue
+		}
+
+		candidates = append(candidates, usage)
+		groupIDs = append(groupIDs, usage.GroupID)
+	}
+
+	if len(candidates) == 0 {
+		return nil, nil
+	}
+
+	// Second pass: total usage of the remaining groups over the baseline window.
+	var threeDayUsages []groupUsage
+
+	err = LogDB.
+		Model(&GroupSummary{}).
+		Select("group_id, SUM(used_amount) as used_amount").
+		Where("group_id IN ?", groupIDs).
+		Where("hour_timestamp BETWEEN ? AND ?", threeDaysAgoStart.Unix(), yesterdayEnd.Unix()).
+		Group("group_id").
+		Find(&threeDayUsages).Error
+	if err != nil {
+		return nil, err
+	}
+
+	// The average is always the total divided by 3, even when a group was
+	// active on fewer than three of those days.
+	threeDayAvg := make(map[string]float64, len(threeDayUsages))
+	for _, usage := range threeDayUsages {
+		threeDayAvg[usage.GroupID] = usage.UsedAmount / 3.0
+	}
+
+	var alerts []GroupUsageAlertItem
+
+	// Walk the query result rather than a map so the output order is stable.
+	for _, usage := range candidates {
+		// Groups missing from the baseline query read as 0 here.
+		avg := threeDayAvg[usage.GroupID]
+
+		// A zero average would divide by zero; an average below the configured
+		// minimum is too small to judge. An average equal to the minimum is
+		// still checked.
+		if avg == 0 || avg < minAvgThreshold {
+			continue
+		}
+
+		ratio := usage.UsedAmount / avg
+		if ratio < calculateSpikeThreshold(avg) {
+			continue
+		}
+
+		alerts = append(alerts, GroupUsageAlertItem{
+			GroupID:           usage.GroupID,
+			ThreeDayAvgAmount: avg,
+			TodayAmount:       usage.UsedAmount,
+			Ratio:             ratio,
+		})
+	}
+
+	return alerts, nil
+}

+ 165 - 0
core/startup.go

@@ -0,0 +1,165 @@
+package main
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/joho/godotenv"
+	"github.com/labring/aiproxy/core/common"
+	"github.com/labring/aiproxy/core/common/balance"
+	"github.com/labring/aiproxy/core/common/config"
+	"github.com/labring/aiproxy/core/common/notify"
+	"github.com/labring/aiproxy/core/common/pprof"
+	"github.com/labring/aiproxy/core/middleware"
+	"github.com/labring/aiproxy/core/model"
+	"github.com/labring/aiproxy/core/router"
+	log "github.com/sirupsen/logrus"
+)
+
+// initializeServices brings up the process-wide dependencies in a fixed order:
+// the pprof server and the notifier first, then Redis, balance, the main
+// database, options and caches, and finally the log database. The first error
+// encountered is returned and the remaining steps are skipped.
+func initializeServices(pprofPort int) error {
+	initializePprof(pprofPort)
+	initializeNotifier()
+
+	steps := []func() error{
+		common.InitRedisClient,
+		initializeBalance,
+		model.InitDB,
+		initializeOptionAndCaches,
+		// Kept as a closure so the batch size is read only once the steps
+		// above have finished, right when the log database is initialised.
+		func() error {
+			return model.InitLogDB(int(config.GetCleanLogBatchSize()))
+		},
+	}
+
+	for _, step := range steps {
+		if err := step(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// initializePprof runs the pprof server on pprofPort in a background goroutine
+// and logs the error if RunPprofServer returns one. It does not wait for the
+// server to start.
+func initializePprof(pprofPort int) {
+	serve := func() {
+		if err := pprof.RunPprofServer(pprofPort); err != nil {
+			log.Errorf("run pprof server error: %v", err)
+		}
+	}
+
+	go serve()
+}
+
+// initializeBalance enables the Sealos balance backend when SEALOS_JWT_KEY is
+// set, passing it together with SEALOS_ACCOUNT_URL to balance.InitSealos and
+// returning that call's error. Without the key it only logs and returns nil.
+func initializeBalance() error {
+	jwtKey := os.Getenv("SEALOS_JWT_KEY")
+	if jwtKey != "" {
+		log.Info("SEALOS_JWT_KEY is set, balance will be enabled")
+
+		return balance.InitSealos(jwtKey, os.Getenv("SEALOS_ACCOUNT_URL"))
+	}
+
+	log.Info("SEALOS_JWT_KEY is not set, balance will not be enabled")
+
+	return nil
+}
+
+// initializeNotifier installs a Feishu notifier as the default notifier when
+// NOTIFY_FEISHU_WEBHOOK is set; otherwise the default notifier is left as is.
+func initializeNotifier() {
+	webhook := os.Getenv("NOTIFY_FEISHU_WEBHOOK")
+	if webhook == "" {
+		return
+	}
+
+	notify.SetDefaultNotifier(notify.NewFeishuNotify(webhook))
+	log.Info("NOTIFY_FEISHU_WEBHOOK is set, notifier will be use feishu")
+}
+
+// initializeOptionAndCaches runs model.InitOption2DB and, only if that
+// succeeds, model.InitModelConfigAndChannelCache. It returns the first error.
+func initializeOptionAndCaches() error {
+	log.Info("starting init config and channel")
+
+	err := model.InitOption2DB()
+	if err == nil {
+		err = model.InitModelConfigAndChannelCache()
+	}
+
+	return err
+}
+
+// startSyncServices launches the two background sync goroutines: options with
+// a 5 second interval and the model config / channel cache with a 10 second
+// interval. It adds 2 to wg before starting them.
+// NOTE(review): presumably each goroutine calls wg.Done when ctx ends;
+// confirm in model.SyncOptions and model.SyncModelConfigAndChannelCache.
+func startSyncServices(ctx context.Context, wg *sync.WaitGroup) {
+	const (
+		optionsInterval = 5 * time.Second
+		cacheInterval   = 10 * time.Second
+	)
+
+	wg.Add(2)
+
+	go model.SyncOptions(ctx, wg, optionsInterval)
+	go model.SyncModelConfigAndChannelCache(ctx, wg, cacheInterval)
+}
+
+// setupHTTPServer builds the gin engine with the recovery, logging, request ID
+// and CORS middleware, registers the routes, and wraps the engine in an
+// http.Server with a 10 second read-header timeout. The server listens on
+// listen unless the LISTEN environment variable is non-empty, in which case
+// that value is used. Both the server and the engine are returned; the server
+// is not started here.
+func setupHTTPServer(listen string) (*http.Server, *gin.Engine) {
+	engine := gin.New()
+	engine.Use(
+		middleware.GinRecoveryHandler,
+		middleware.NewLog(log.StandardLogger()),
+		middleware.RequestIDMiddleware,
+		middleware.CORS(),
+	)
+	router.SetRouter(engine)
+
+	addr := listen
+	if fromEnv := os.Getenv("LISTEN"); fromEnv != "" {
+		addr = fromEnv
+	}
+
+	srv := &http.Server{
+		Addr:              addr,
+		ReadHeaderTimeout: 10 * time.Second,
+		Handler:           engine,
+	}
+
+	return srv, engine
+}
+
+// loadedEnvFiles holds the absolute paths of the env files loadEnv has loaded.
+var loadedEnvFiles []string
+
+// loadEnv loads ".env" and then ".env.local" from the working directory with
+// godotenv.Overload. A file that cannot be stat'ed or that is a directory is
+// skipped; failing to resolve its absolute path or to load it panics. Each
+// loaded file's absolute path is appended to loadedEnvFiles.
+func loadEnv() {
+	for _, name := range [...]string{".env", ".env.local"} {
+		absPath, err := filepath.Abs(name)
+		if err != nil {
+			panic(
+				fmt.Sprintf(
+					"failed to get absolute path of env file: %s, error: %s",
+					name,
+					err.Error(),
+				),
+			)
+		}
+
+		info, statErr := os.Stat(absPath)
+		if statErr != nil || info.IsDir() {
+			continue
+		}
+
+		if loadErr := godotenv.Overload(absPath); loadErr != nil {
+			panic(fmt.Sprintf("failed to load env file: %s, error: %s", absPath, loadErr.Error()))
+		}
+
+		loadedEnvFiles = append(loadedEnvFiles, absPath)
+	}
+}
+
+// printLoadedEnvFiles logs one info line per entry of loadedEnvFiles.
+func printLoadedEnvFiles() {
+	for i := range loadedEnvFiles {
+		log.Infof("loaded env file: %s", loadedEnvFiles[i])
+	}
+}
+
+// listenAndServe blocks in srv.ListenAndServe. It returns quietly when the
+// call ends without an error or with http.ErrServerClosed; any other error is
+// reported through log.Fatal.
+func listenAndServe(srv *http.Server) {
+	err := srv.ListenAndServe()
+	if err == nil || errors.Is(err, http.ErrServerClosed) {
+		return
+	}
+
+	log.Fatal("failed to start HTTP server: " + err.Error())
+}

+ 254 - 0
core/task/task.go

@@ -0,0 +1,254 @@
+package task
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"fmt"
+	"slices"
+	"time"
+
+	"github.com/bytedance/sonic"
+	"github.com/labring/aiproxy/core/common/config"
+	"github.com/labring/aiproxy/core/common/conv"
+	"github.com/labring/aiproxy/core/common/ipblack"
+	"github.com/labring/aiproxy/core/common/notify"
+	"github.com/labring/aiproxy/core/common/trylock"
+	"github.com/labring/aiproxy/core/controller"
+	"github.com/labring/aiproxy/core/model"
+)
+
+// AutoTestBannedModelsTask calls controller.AutoTestBannedModels every 30
+// seconds until ctx is done.
+func AutoTestBannedModelsTask(ctx context.Context) {
+	const interval = 30 * time.Second
+
+	tick := time.NewTicker(interval)
+	defer tick.Stop()
+
+	done := ctx.Done()
+
+	for {
+		select {
+		case <-tick.C:
+			controller.AutoTestBannedModels()
+		case <-done:
+			return
+		}
+	}
+}
+
+// DetectIPGroupsTask wakes up once a minute until ctx is done and runs
+// detectIPGroups whenever the "runDetectIPGroups" trylock is obtained.
+// NOTE(review): the lock presumably keeps the check to one run per minute
+// across instances; confirm against trylock.Lock.
+func DetectIPGroupsTask(ctx context.Context) {
+	tick := time.NewTicker(time.Minute)
+	defer tick.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-tick.C:
+		}
+
+		if trylock.Lock("runDetectIPGroups", time.Minute) {
+			detectIPGroups()
+		}
+	}
+}
+
+// detectIPGroups inspects the IP-to-groups mapping of the last hour. It does
+// nothing when the configured threshold is below 1 or when the lookup fails
+// (the failure is reported through a throttled error notification).
+//
+// For every IP returned, the group list is sorted and serialised. If the ban
+// threshold is at least the warn threshold and the IP uses at least that many
+// groups, the groups are disabled; when rows were actually updated a warning
+// is sent and the IP is blacklisted for 48 hours. Otherwise a warning is sent,
+// throttled for 3 hours per IP and group set.
+func detectIPGroups() {
+	warnAt := config.GetIPGroupsThreshold()
+	if warnAt < 1 {
+		return
+	}
+
+	ipGroupList, err := model.GetIPGroups(int(warnAt), time.Now().Add(-time.Hour), time.Now())
+	if err != nil {
+		notify.ErrorThrottle("detectIPGroups", time.Minute, "detect IP groups failed", err.Error())
+		return
+	}
+
+	if len(ipGroupList) == 0 {
+		return
+	}
+
+	banAt := config.GetIPGroupsBanThreshold()
+
+	for ip, groups := range ipGroupList {
+		// Sorting makes the JSON, and therefore the throttle key, independent
+		// of the order in which the groups were returned.
+		slices.Sort(groups)
+
+		groupsJSON, err := sonic.MarshalString(groups)
+		if err != nil {
+			notify.ErrorThrottle(
+				"detectIPGroupsMarshal",
+				time.Minute,
+				"marshal IP groups failed",
+				err.Error(),
+			)
+
+			continue
+		}
+
+		shouldBan := banAt >= warnAt && len(groups) >= int(banAt)
+		if !shouldBan {
+			digest := sha256.Sum256(conv.StringToBytes(groupsJSON))
+			throttleKey := fmt.Sprintf("%s:%s", ip, hex.EncodeToString(digest[:]))
+
+			notify.WarnThrottle(
+				throttleKey,
+				time.Hour*3,
+				fmt.Sprintf(
+					"Potential abuse: IP %s is using %d groups (exceeds threshold of %d)",
+					ip,
+					len(groups),
+					warnAt,
+				),
+				groupsJSON,
+			)
+
+			continue
+		}
+
+		rowsAffected, err := model.UpdateGroupsStatus(groups, model.GroupStatusDisabled)
+		if err != nil {
+			notify.ErrorThrottle(
+				"detectIPGroupsBan",
+				time.Minute,
+				"update groups status failed",
+				err.Error(),
+			)
+		}
+
+		// Only announce the ban and blacklist the IP when rows really changed.
+		if rowsAffected > 0 {
+			notify.Warn(
+				fmt.Sprintf(
+					"Suspicious activity: IP %s is using %d groups (exceeds ban threshold of %d). IP and all groups have been disabled.",
+					ip,
+					len(groups),
+					banAt,
+				),
+				groupsJSON,
+			)
+			ipblack.SetIPBlackAnyWay(ip, time.Hour*48)
+		}
+	}
+}
+
+// UsageAlertTask wakes up once an hour until ctx is done and runs
+// checkUsageAlert whenever the "runUsageAlert" trylock is obtained.
+// NOTE(review): the lock presumably keeps the check to one run per hour
+// across instances; confirm against trylock.Lock.
+func UsageAlertTask(ctx context.Context) {
+	tick := time.NewTicker(time.Hour)
+	defer tick.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-tick.C:
+		}
+
+		if trylock.Lock("runUsageAlert", time.Hour) {
+			checkUsageAlert()
+		}
+	}
+}
+
+// checkUsageAlert runs one usage-spike check and sends a single warning that
+// lists every newly flagged group. It does nothing when the configured
+// threshold is not positive. A failed lookup is reported through a throttled
+// error notification.
+func checkUsageAlert() {
+	threshold := config.GetUsageAlertThreshold()
+	if threshold <= 0 {
+		return
+	}
+
+	alerts, err := model.GetGroupUsageAlert(
+		float64(threshold),
+		float64(config.GetUsageAlertMinAvgThreshold()),
+		config.GetUsageAlertWhitelist(),
+	)
+	if err != nil {
+		notify.ErrorThrottle(
+			"usageAlertError",
+			time.Minute*5,
+			"check usage alert failed",
+			err.Error(),
+		)
+
+		return
+	}
+
+	if len(alerts) == 0 {
+		return
+	}
+
+	// Per-group locks last until the next local midnight so that each group is
+	// reported at most once per day.
+	now := time.Now()
+	nextMidnight := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
+	untilMidnight := nextMidnight.Sub(now)
+
+	// A group whose lock cannot be taken was already reported today.
+	var fresh []model.GroupUsageAlertItem
+
+	for i := range alerts {
+		if !trylock.Lock("usageAlert:"+alerts[i].GroupID, untilMidnight) {
+			continue
+		}
+
+		fresh = append(fresh, alerts[i])
+	}
+
+	if len(fresh) == 0 {
+		return
+	}
+
+	notify.Warn(
+		fmt.Sprintf("Detected %d groups with abnormal usage", len(fresh)),
+		formatGroupUsageAlerts(fresh),
+	)
+}
+
+// formatGroupUsageAlerts renders the alert message body: one line per alert
+// with the group ID, the three-day average and today's amount (four decimals
+// each) and the ratio (two decimals), every line ending in a newline. It
+// returns the empty string for an empty slice.
+func formatGroupUsageAlerts(alerts []model.GroupUsageAlertItem) string {
+	if len(alerts) == 0 {
+		return ""
+	}
+
+	// Append into one buffer instead of concatenating strings, which would
+	// re-copy the accumulated text for every alert.
+	var buf []byte
+
+	for i := range alerts {
+		buf = fmt.Appendf(
+			buf,
+			"GroupID: %s | 3-Day Avg: %.4f | Today: %.4f | Ratio: %.2fx\n",
+			alerts[i].GroupID,
+			alerts[i].ThreeDayAvgAmount,
+			alerts[i].TodayAmount,
+			alerts[i].Ratio,
+		)
+	}
+
+	return string(buf)
+}
+
+// CleanLogTask wakes up every 5 seconds until ctx is done and runs
+// model.CleanLog whenever the "runCleanLog" trylock is obtained. The optimize
+// flag passed to CleanLog is the result of trying the 24-hour
+// "runOptimizeLog" lock. Failures are reported through a throttled error
+// notification.
+func CleanLogTask(ctx context.Context) {
+	// the interval should not be too large to avoid cleaning too much at once
+	const interval = 5 * time.Second
+
+	tick := time.NewTicker(interval)
+	defer tick.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-tick.C:
+		}
+
+		if !trylock.Lock("runCleanLog", interval) {
+			continue
+		}
+
+		runOptimize := trylock.Lock("runOptimizeLog", time.Hour*24)
+
+		if err := model.CleanLog(int(config.GetCleanLogBatchSize()), runOptimize); err != nil {
+			notify.ErrorThrottle(
+				"cleanLogError",
+				time.Minute*5,
+				"clean log failed",
+				err.Error(),
+			)
+		}
+	}
+}