task.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package task
  2. import (
  3. "context"
  4. "crypto/sha256"
  5. "encoding/hex"
  6. "fmt"
  7. "slices"
  8. "strings"
  9. "time"
  10. "github.com/bytedance/sonic"
  11. "github.com/labring/aiproxy/core/common/config"
  12. "github.com/labring/aiproxy/core/common/conv"
  13. "github.com/labring/aiproxy/core/common/ipblack"
  14. "github.com/labring/aiproxy/core/common/notify"
  15. "github.com/labring/aiproxy/core/common/trylock"
  16. "github.com/labring/aiproxy/core/controller"
  17. "github.com/labring/aiproxy/core/model"
  18. )
  19. // AutoTestBannedModelsTask 自动测试被禁用的模型
  20. func AutoTestBannedModelsTask(ctx context.Context) {
  21. ticker := time.NewTicker(time.Second * 30)
  22. defer ticker.Stop()
  23. for {
  24. select {
  25. case <-ctx.Done():
  26. return
  27. case <-ticker.C:
  28. controller.AutoTestBannedModels()
  29. }
  30. }
  31. }
  32. // DetectIPGroupsTask 检测 IP 使用多个 group 的情况
  33. func DetectIPGroupsTask(ctx context.Context) {
  34. ticker := time.NewTicker(time.Minute)
  35. defer ticker.Stop()
  36. for {
  37. select {
  38. case <-ctx.Done():
  39. return
  40. case <-ticker.C:
  41. if !trylock.Lock("runDetectIPGroups", time.Minute) {
  42. continue
  43. }
  44. detectIPGroups()
  45. }
  46. }
  47. }
  48. func detectIPGroups() {
  49. threshold := config.GetIPGroupsThreshold()
  50. if threshold < 1 {
  51. return
  52. }
  53. ipGroupList, err := model.GetIPGroups(int(threshold), time.Now().Add(-time.Hour), time.Now())
  54. if err != nil {
  55. notify.ErrorThrottle("detectIPGroups", time.Minute, "detect IP groups failed", err.Error())
  56. return
  57. }
  58. if len(ipGroupList) == 0 {
  59. return
  60. }
  61. banThreshold := config.GetIPGroupsBanThreshold()
  62. for ip, groups := range ipGroupList {
  63. slices.Sort(groups)
  64. groupsJSON, err := sonic.MarshalString(groups)
  65. if err != nil {
  66. notify.ErrorThrottle(
  67. "detectIPGroupsMarshal",
  68. time.Minute,
  69. "marshal IP groups failed",
  70. err.Error(),
  71. )
  72. continue
  73. }
  74. if banThreshold >= threshold && len(groups) >= int(banThreshold) {
  75. rowsAffected, err := model.UpdateGroupsStatus(groups, model.GroupStatusDisabled)
  76. if err != nil {
  77. notify.ErrorThrottle(
  78. "detectIPGroupsBan",
  79. time.Minute,
  80. "update groups status failed",
  81. err.Error(),
  82. )
  83. }
  84. if rowsAffected > 0 {
  85. notify.Warn(
  86. fmt.Sprintf(
  87. "Suspicious activity: IP %s is using %d groups (exceeds ban threshold of %d). IP and all groups have been disabled.",
  88. ip,
  89. len(groups),
  90. banThreshold,
  91. ),
  92. groupsJSON,
  93. )
  94. ipblack.SetIPBlackAnyWay(ip, time.Hour*48)
  95. }
  96. continue
  97. }
  98. h := sha256.New()
  99. h.Write(conv.StringToBytes(groupsJSON))
  100. groupsHash := hex.EncodeToString(h.Sum(nil))
  101. hashKey := fmt.Sprintf("%s:%s", ip, groupsHash)
  102. notify.WarnThrottle(
  103. hashKey,
  104. time.Hour*3,
  105. fmt.Sprintf(
  106. "Potential abuse: IP %s is using %d groups (exceeds threshold of %d)",
  107. ip,
  108. len(groups),
  109. threshold,
  110. ),
  111. groupsJSON,
  112. )
  113. }
  114. }
  115. // UsageAlertTask 用量异常告警任务
  116. func UsageAlertTask(ctx context.Context) {
  117. ticker := time.NewTicker(time.Hour)
  118. defer ticker.Stop()
  119. for {
  120. select {
  121. case <-ctx.Done():
  122. return
  123. case <-ticker.C:
  124. if !trylock.Lock("runUsageAlert", time.Hour) {
  125. continue
  126. }
  127. checkUsageAlert()
  128. }
  129. }
  130. }
  131. func checkUsageAlert() {
  132. threshold := config.GetUsageAlertThreshold()
  133. if threshold <= 0 {
  134. return
  135. }
  136. // 获取配置的白名单
  137. whitelist := config.GetUsageAlertWhitelist()
  138. // 获取前三天平均用量最低阈值
  139. minAvgThreshold := config.GetUsageAlertMinAvgThreshold()
  140. alerts, err := model.GetGroupUsageAlert(float64(threshold), float64(minAvgThreshold), whitelist)
  141. if err != nil {
  142. notify.ErrorThrottle(
  143. "usageAlertError",
  144. time.Minute*5,
  145. "check usage alert failed",
  146. err.Error(),
  147. )
  148. return
  149. }
  150. if len(alerts) == 0 {
  151. return
  152. }
  153. // 计算到明天 0 点的时间,确保每个 group 一天只告警一次
  154. now := time.Now()
  155. tomorrow := time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
  156. lockDuration := tomorrow.Sub(now)
  157. // 过滤掉当天已经告警过的 group(通过 trylock 判断)
  158. var validAlerts []model.GroupUsageAlertItem
  159. for _, alert := range alerts {
  160. lockKey := "usageAlert:" + alert.GroupID
  161. // 尝试获取锁,如果获取失败说明当天已经告警过
  162. if trylock.Lock(lockKey, lockDuration) {
  163. validAlerts = append(validAlerts, alert)
  164. }
  165. }
  166. if len(validAlerts) == 0 {
  167. return
  168. }
  169. message := formatGroupUsageAlerts(validAlerts)
  170. notify.Warn(
  171. fmt.Sprintf("Detected %d groups with abnormal usage", len(validAlerts)),
  172. message,
  173. )
  174. }
  175. // formatGroupUsageAlerts 格式化告警消息
  176. func formatGroupUsageAlerts(alerts []model.GroupUsageAlertItem) string {
  177. if len(alerts) == 0 {
  178. return ""
  179. }
  180. var result strings.Builder
  181. for _, alert := range alerts {
  182. result.WriteString(fmt.Sprintf(
  183. "GroupID: %s | 3-Day Avg: %.4f | Today: %.4f | Ratio: %.2fx\n",
  184. alert.GroupID,
  185. alert.ThreeDayAvgAmount,
  186. alert.TodayAmount,
  187. alert.Ratio,
  188. ))
  189. }
  190. return result.String()
  191. }
  192. // CleanLogTask 清理日志任务
  193. func CleanLogTask(ctx context.Context) {
  194. // the interval should not be too large to avoid cleaning too much at once
  195. ticker := time.NewTicker(time.Second * 5)
  196. defer ticker.Stop()
  197. for {
  198. select {
  199. case <-ctx.Done():
  200. return
  201. case <-ticker.C:
  202. if !trylock.Lock("runCleanLog", time.Second*3) {
  203. continue
  204. }
  205. optimize := trylock.Lock("runOptimizeLog", time.Hour*24)
  206. err := model.CleanLog(int(config.GetCleanLogBatchSize()), optimize)
  207. if err != nil {
  208. notify.ErrorThrottle(
  209. "cleanLogError",
  210. time.Minute*5,
  211. "clean log failed",
  212. err.Error(),
  213. )
  214. }
  215. }
  216. }
  217. }