task.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. package task
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/hex"
  6. "fmt"
  7. "slices"
  8. "time"
  9. "github.com/bytedance/sonic"
  10. "github.com/labring/aiproxy/core/common/config"
  11. "github.com/labring/aiproxy/core/common/conv"
  12. "github.com/labring/aiproxy/core/common/ipblack"
  13. "github.com/labring/aiproxy/core/common/notify"
  14. "github.com/labring/aiproxy/core/common/trylock"
  15. "github.com/labring/aiproxy/core/controller"
  16. "github.com/labring/aiproxy/core/model"
  17. )
  18. // AutoTestBannedModelsTask 自动测试被禁用的模型
  19. func AutoTestBannedModelsTask(ctx context.Context) {
  20. ticker := time.NewTicker(time.Second * 30)
  21. defer ticker.Stop()
  22. for {
  23. select {
  24. case <-ctx.Done():
  25. return
  26. case <-ticker.C:
  27. controller.AutoTestBannedModels()
  28. }
  29. }
  30. }
  31. // DetectIPGroupsTask 检测 IP 使用多个 group 的情况
  32. func DetectIPGroupsTask(ctx context.Context) {
  33. ticker := time.NewTicker(time.Minute)
  34. defer ticker.Stop()
  35. for {
  36. select {
  37. case <-ctx.Done():
  38. return
  39. case <-ticker.C:
  40. if !trylock.Lock("runDetectIPGroups", time.Minute) {
  41. continue
  42. }
  43. detectIPGroups()
  44. }
  45. }
  46. }
  47. func detectIPGroups() {
  48. threshold := config.GetIPGroupsThreshold()
  49. if threshold < 1 {
  50. return
  51. }
  52. ipGroupList, err := model.GetIPGroups(int(threshold), time.Now().Add(-time.Hour), time.Now())
  53. if err != nil {
  54. notify.ErrorThrottle("detectIPGroups", time.Minute, "detect IP groups failed", err.Error())
  55. return
  56. }
  57. if len(ipGroupList) == 0 {
  58. return
  59. }
  60. banThreshold := config.GetIPGroupsBanThreshold()
  61. for ip, groups := range ipGroupList {
  62. slices.Sort(groups)
  63. groupsJSON, err := sonic.MarshalString(groups)
  64. if err != nil {
  65. notify.ErrorThrottle(
  66. "detectIPGroupsMarshal",
  67. time.Minute,
  68. "marshal IP groups failed",
  69. err.Error(),
  70. )
  71. continue
  72. }
  73. if banThreshold >= threshold && len(groups) >= int(banThreshold) {
  74. rowsAffected, err := model.UpdateGroupsStatus(groups, model.GroupStatusDisabled)
  75. if err != nil {
  76. notify.ErrorThrottle(
  77. "detectIPGroupsBan",
  78. time.Minute,
  79. "update groups status failed",
  80. err.Error(),
  81. )
  82. }
  83. if rowsAffected > 0 {
  84. notify.Warn(
  85. fmt.Sprintf(
  86. "Suspicious activity: IP %s is using %d groups (exceeds ban threshold of %d). IP and all groups have been disabled.",
  87. ip,
  88. len(groups),
  89. banThreshold,
  90. ),
  91. groupsJSON,
  92. )
  93. ipblack.SetIPBlackAnyWay(ip, time.Hour*48)
  94. }
  95. continue
  96. }
  97. h := sha256.New()
  98. h.Write(conv.StringToBytes(groupsJSON))
  99. groupsHash := hex.EncodeToString(h.Sum(nil))
  100. hashKey := fmt.Sprintf("%s:%s", ip, groupsHash)
  101. notify.WarnThrottle(
  102. hashKey,
  103. time.Hour*3,
  104. fmt.Sprintf(
  105. "Potential abuse: IP %s is using %d groups (exceeds threshold of %d)",
  106. ip,
  107. len(groups),
  108. threshold,
  109. ),
  110. groupsJSON,
  111. )
  112. }
  113. }
  114. // UsageAlertTask 用量异常告警任务
  115. func UsageAlertTask(ctx context.Context) {
  116. ticker := time.NewTicker(time.Hour)
  117. defer ticker.Stop()
  118. for {
  119. select {
  120. case <-ctx.Done():
  121. return
  122. case <-ticker.C:
  123. if !trylock.Lock("runUsageAlert", time.Hour) {
  124. continue
  125. }
  126. checkUsageAlert()
  127. }
  128. }
  129. }
  130. func checkUsageAlert() {
  131. threshold := config.GetUsageAlertThreshold()
  132. if threshold <= 0 {
  133. return
  134. }
  135. // 获取配置的白名单
  136. whitelist := config.GetUsageAlertWhitelist()
  137. // 获取前三天平均用量最低阈值
  138. minAvgThreshold := config.GetUsageAlertMinAvgThreshold()
  139. alerts, err := model.GetGroupUsageAlert(float64(threshold), float64(minAvgThreshold), whitelist)
  140. if err != nil {
  141. notify.ErrorThrottle(
  142. "usageAlertError",
  143. time.Minute*5,
  144. "check usage alert failed",
  145. err.Error(),
  146. )
  147. return
  148. }
  149. if len(alerts) == 0 {
  150. return
  151. }
  152. // 计算到明天 0 点的时间,确保每个 group 一天只告警一次
  153. now := time.Now()
  154. tomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
  155. lockDuration := tomorrow.Sub(now)
  156. // 过滤掉当天已经告警过的 group(通过 trylock 判断)
  157. var validAlerts []model.GroupUsageAlertItem
  158. for _, alert := range alerts {
  159. lockKey := "usageAlert:" + alert.GroupID
  160. // 尝试获取锁,如果获取失败说明当天已经告警过
  161. if trylock.Lock(lockKey, lockDuration) {
  162. validAlerts = append(validAlerts, alert)
  163. }
  164. }
  165. if len(validAlerts) == 0 {
  166. return
  167. }
  168. message := formatGroupUsageAlerts(validAlerts)
  169. notify.Warn(
  170. fmt.Sprintf("Detected %d groups with abnormal usage", len(validAlerts)),
  171. message,
  172. )
  173. }
  174. // formatGroupUsageAlerts 格式化告警消息
  175. func formatGroupUsageAlerts(alerts []model.GroupUsageAlertItem) string {
  176. if len(alerts) == 0 {
  177. return ""
  178. }
  179. var result string
  180. for _, alert := range alerts {
  181. result += fmt.Sprintf(
  182. "GroupID: %s | 3-Day Avg: %.4f | Today: %.4f | Ratio: %.2fx\n",
  183. alert.GroupID,
  184. alert.ThreeDayAvgAmount,
  185. alert.TodayAmount,
  186. alert.Ratio,
  187. )
  188. }
  189. return result
  190. }
  191. // CleanLogTask 清理日志任务
  192. func CleanLogTask(ctx context.Context) {
  193. // the interval should not be too large to avoid cleaning too much at once
  194. ticker := time.NewTicker(time.Second * 5)
  195. defer ticker.Stop()
  196. for {
  197. select {
  198. case <-ctx.Done():
  199. return
  200. case <-ticker.C:
  201. if !trylock.Lock("runCleanLog", time.Second*5) {
  202. continue
  203. }
  204. optimize := trylock.Lock("runOptimizeLog", time.Hour*24)
  205. err := model.CleanLog(int(config.GetCleanLogBatchSize()), optimize)
  206. if err != nil {
  207. notify.ErrorThrottle(
  208. "cleanLogError",
  209. time.Minute*5,
  210. "clean log failed",
  211. err.Error(),
  212. )
  213. }
  214. }
  215. }
  216. }