chain.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package hooks
  2. import (
  3. "fmt"
  4. "sync"
  5. "github.com/QuantumNous/new-api/common"
  6. "github.com/QuantumNous/new-api/core/interfaces"
  7. "github.com/QuantumNous/new-api/core/registry"
  8. )
  9. var (
  10. // 全局Hook链实例(单例)
  11. globalChain *HookChain
  12. globalChainOnce sync.Once
  13. )
  14. // HookChain Hook执行链
  15. type HookChain struct {
  16. hooks []interfaces.RelayHook
  17. mu sync.RWMutex
  18. }
  19. // GetGlobalChain 获取全局Hook链实例
  20. func GetGlobalChain() *HookChain {
  21. globalChainOnce.Do(func() {
  22. globalChain = &HookChain{
  23. hooks: make([]interfaces.RelayHook, 0),
  24. }
  25. // 从注册中心加载hooks
  26. globalChain.LoadHooks()
  27. })
  28. return globalChain
  29. }
  30. // LoadHooks 从注册中心加载hooks
  31. func (c *HookChain) LoadHooks() {
  32. c.mu.Lock()
  33. defer c.mu.Unlock()
  34. c.hooks = registry.ListHooks()
  35. common.SysLog(fmt.Sprintf("Loaded %d enabled hooks", len(c.hooks)))
  36. }
  37. // ReloadHooks 重新加载hooks
  38. func (c *HookChain) ReloadHooks() {
  39. c.LoadHooks()
  40. common.SysLog("Hooks reloaded")
  41. }
  42. // ExecuteBeforeRequest 执行所有BeforeRequest钩子
  43. func (c *HookChain) ExecuteBeforeRequest(ctx *interfaces.HookContext) error {
  44. c.mu.RLock()
  45. hooks := c.hooks
  46. c.mu.RUnlock()
  47. for _, hook := range hooks {
  48. if !hook.Enabled() {
  49. continue
  50. }
  51. if ctx.ShouldSkip {
  52. break
  53. }
  54. if err := hook.OnBeforeRequest(ctx); err != nil {
  55. common.SysError(fmt.Sprintf("Hook %s OnBeforeRequest error: %v", hook.Name(), err))
  56. return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
  57. }
  58. }
  59. return nil
  60. }
  61. // ExecuteAfterResponse 执行所有AfterResponse钩子
  62. func (c *HookChain) ExecuteAfterResponse(ctx *interfaces.HookContext) error {
  63. c.mu.RLock()
  64. hooks := c.hooks
  65. c.mu.RUnlock()
  66. for _, hook := range hooks {
  67. if !hook.Enabled() {
  68. continue
  69. }
  70. if ctx.ShouldSkip {
  71. break
  72. }
  73. if err := hook.OnAfterResponse(ctx); err != nil {
  74. common.SysError(fmt.Sprintf("Hook %s OnAfterResponse error: %v", hook.Name(), err))
  75. return fmt.Errorf("hook %s failed: %w", hook.Name(), err)
  76. }
  77. }
  78. return nil
  79. }
  80. // ExecuteOnError 执行所有OnError钩子
  81. func (c *HookChain) ExecuteOnError(ctx *interfaces.HookContext, err error) error {
  82. c.mu.RLock()
  83. hooks := c.hooks
  84. c.mu.RUnlock()
  85. for _, hook := range hooks {
  86. if !hook.Enabled() {
  87. continue
  88. }
  89. if hookErr := hook.OnError(ctx, err); hookErr != nil {
  90. common.SysError(fmt.Sprintf("Hook %s OnError failed: %v", hook.Name(), hookErr))
  91. // OnError钩子的错误不会中断执行
  92. }
  93. }
  94. return err
  95. }
  96. // GetHooks 获取当前hook列表
  97. func (c *HookChain) GetHooks() []interfaces.RelayHook {
  98. c.mu.RLock()
  99. defer c.mu.RUnlock()
  100. hooks := make([]interfaces.RelayHook, len(c.hooks))
  101. copy(hooks, c.hooks)
  102. return hooks
  103. }
  104. // Count 返回hook数量
  105. func (c *HookChain) Count() int {
  106. c.mu.RLock()
  107. defer c.mu.RUnlock()
  108. return len(c.hooks)
  109. }