sensitive.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "one-api/common"
  6. "one-api/constant"
  7. "one-api/dto"
  8. "strings"
  9. )
  10. func CheckSensitiveMessages(messages []dto.Message) error {
  11. for _, message := range messages {
  12. if len(message.Content) > 0 {
  13. if message.IsStringContent() {
  14. stringContent := message.StringContent()
  15. if ok, words := SensitiveWordContains(stringContent); ok {
  16. return errors.New("sensitive words: " + strings.Join(words, ","))
  17. }
  18. }
  19. } else {
  20. arrayContent := message.ParseContent()
  21. for _, m := range arrayContent {
  22. if m.Type == "image_url" {
  23. // TODO: check image url
  24. } else {
  25. if ok, words := SensitiveWordContains(m.Text); ok {
  26. return errors.New("sensitive words: " + strings.Join(words, ","))
  27. }
  28. }
  29. }
  30. }
  31. }
  32. return nil
  33. }
  // CheckSensitiveText checks a single piece of text for configured sensitive
  // words. It returns nil when nothing matches; otherwise it returns an error
  // whose message lists the matched words separated by commas.
  func CheckSensitiveText(text string) error {
  	found, matched := SensitiveWordContains(text)
  	if !found {
  		return nil
  	}
  	return errors.New("sensitive words: " + strings.Join(matched, ","))
  }
  // CheckSensitiveInput checks an arbitrary request input for sensitive words
  // and returns the same error CheckSensitiveText would, or nil.
  //
  //   - string:   checked as-is.
  //   - []string: checked as the concatenation of its elements (no separator).
  //   - anything else: checked via its fmt "%v" representation.
  //
  // NOTE(review): because []string elements are joined with no separator, a
  // match can span the end of one element and the start of the next — confirm
  // this is intended.
  func CheckSensitiveInput(input any) error {
  	switch v := input.(type) {
  	case string:
  		return CheckSensitiveText(v)
  	case []string:
  		// strings.Join sizes the result once; repeated += copied the accumulated
  		// text on every iteration (quadratic in the total length).
  		return CheckSensitiveText(strings.Join(v, ""))
  	default:
  		return CheckSensitiveText(fmt.Sprintf("%v", input))
  	}
  }
  // SensitiveWordContains reports whether text contains any configured
  // sensitive word, and returns the list of matched words when it does.
  //
  // The search runs over a lower-cased copy of text. When no sensitive words
  // are configured it returns (false, nil) without searching.
  func SensitiveWordContains(text string) (bool, []string) {
  	if len(constant.SensitiveWords) == 0 {
  		return false, nil
  	}
  	// Build an Aho-Corasick (AC) automaton over the configured words.
  	// NOTE(review): InitAc is called on every invocation; if it rebuilds the
  	// automaton each time, caching it could help hot paths — confirm.
  	automaton := common.InitAc()
  	hits := automaton.MultiPatternSearch([]rune(strings.ToLower(text)), false)
  	if len(hits) == 0 {
  		return false, nil
  	}
  	matched := make([]string, 0, len(hits))
  	for _, hit := range hits {
  		matched = append(matched, string(hit.Word))
  	}
  	return true, matched
  }
  71. // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
  72. func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
  73. if len(constant.SensitiveWords) == 0 {
  74. return false, nil, text
  75. }
  76. checkText := strings.ToLower(text)
  77. m := common.InitAc()
  78. hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
  79. if len(hits) > 0 {
  80. words := make([]string, 0)
  81. for _, hit := range hits {
  82. pos := hit.Pos
  83. word := string(hit.Word)
  84. text = text[:pos] + "**###**" + text[pos+len(word):]
  85. words = append(words, word)
  86. }
  87. return true, words, text
  88. }
  89. return false, nil, text
  90. }